From db87a2b2c6c82482f50c7139af872b2e1d08d286 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 27 Aug 2018 13:55:52 -0700 Subject: [PATCH 01/21] add Generic MXNetHandle trait and MXNetHandlePhantomRef class that will be used by all MXNetObjects --- .../scala/org/apache/mxnet/MXNetHandle.scala | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala new file mode 100644 index 000000000000..39046b584234 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.CPtrAddress + +import scala.ref.{PhantomReference, ReferenceQueue} +import java.util.concurrent.ConcurrentHashMap +import org.apache.mxnet.Base.checkCall + +/** + * Should be generic to All MXNet Objects + * Should call the DeAlloc automatically + * Should do it only if not disposed. ie., dispose removes from the refQ + */ + +private[mxnet] trait MXNetHandle extends AutoCloseable { + + val nativeAddress: CPtrAddress + + val bytesAllocated: Long + + var isDisposed: Boolean = false + + val deAllocFn = (mxFreeHandleAddress: CPtrAddress => Int) + + def register(referent: MXNetHandle): Unit = { + MXNetHandlePhantomRef.register(this, deAllocFn) + } + + def deRegister(referent: MXNetHandle): Unit = { + MXNetHandlePhantomRef.deRegister(referent) + } + + /* call {@link deAllocFn} if !{@link isDispose} */ + def dispose(): Unit = { + if (!isDisposed) { + checkCall(checkdeAllocFn) + isDisposed = true + deRegister(this) + } + } + + override def close(): Unit = { + dispose() + } +} + +/** + * Fill me in + * @param h + * @param deAllocFn + */ +private[mxnet] class MXNetHandlePhantomRef(h: MXNetHandle, val deAllocFn: CPtrAddress => Int) + extends PhantomReference[MXNetHandle](h, refQ) { +} + +object MXNetHandlePhantomRef { + private val refQ: ReferenceQueue[MXNetHandle] = new ReferenceQueue[MXNetHandle] + + private val refs = new ConcurrentHashMap[MXNetHandlePhantomRef, CPtrAddress]() + + def register(referent: MXNetHandle, deAllocFn: CPtrAddress => Int): MXNetHandlePhantomRef = { + val ref = new MXNetHandlePhantomRef(referent, deAllocFn) + refs.put(ref, referent.nativeAddress) + ref + } + + def deRegister(referent: MXNetHandlePhantomRef): Unit = { + if ((r = refs.get(referent)) != null) { + refs.remove(referent) + } + } + + def cleanUp(): Unit = { + var ref: MXNetHandlePhantomRef = refQ.poll().asInstanceOf[MXNetHandlePhantomRef] + while (ref != null) { + // may be dispose or close was called on this + if ((hdl = refs.get(ref)) != null) { + ref.deAllocFn(hdl) + refs.remove(ref) + } + ref = refQ.poll().asInstanceOf[MXNetHandlePhantomRef] + } + } +} \ No newline at end of file From cba8a4397a5e5c2dff27b10a6056c288ad9d7691 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 27 Aug 2018 14:28:22 -0700 Subject: [PATCH 02/21] use nswamy@ personal repo for mac testing --- .../init/src/main/scala/org/apache/mxnet/init/Base.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index 7402dbd3bc1d..a43df10fcb7e 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -48,6 +48,7 @@ object Base { if (os.startsWith("Linux")) { System.load(s"$baseDir/linux-x86_64/target/libmxnet-init-scala-linux-x86_64.so") } else if (os.startsWith("Mac")) { + baseDir = "/Users/wamy/nswamy/deepengine/workspace/mxnet_scala/scala-package/init-native" System.load(s"$baseDir/osx-x86_64/target/libmxnet-init-scala-osx-x86_64.jnilib") } else { // TODO(yizhi) support windows later From 34106a40998520b994acdac7a769b6c03afcb00f Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 27 Aug 2018 17:46:49 -0700 Subject: [PATCH 03/21] Generic Handle with AutoCloseable --- .../org/apache/mxnet/MXNativeHandle.scala | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala new file mode 100644 index 000000000000..5aaee7effaad --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.CPtrAddress + +import java.lang.ref.{PhantomReference, ReferenceQueue} +import java.util.concurrent.ConcurrentHashMap +import org.apache.mxnet.Base.checkCall + +/** + * Should be generic to All MXNet Objects + * Should call the DeAlloc automatically + * Should do it only if not disposed. ie., dispose removes from the refQ + */ + +private[mxnet] trait MXNativeHandle extends AutoCloseable { + + def nativeAddress: CPtrAddress + + def nativeDeAllocAddress: (CPtrAddress => Int) + + val phantomRef: MXHandlePhantomRef + + def bytesAllocated: Long + + var isDisposed: Boolean = false + + def register(referent: MXNativeHandle): MXHandlePhantomRef = { + MXHandlePhantomRef.register(this, nativeDeAllocAddress) + } + + def deRegister(phantomRef: MXHandlePhantomRef): Unit = { + MXHandlePhantomRef.deRegister(phantomRef) + } + + /* call {@link deAllocFn} if !{@link isDispose} */ + def dispose(): Unit = { + print("dispose called") + if (!isDisposed) { + checkCall(nativeDeAllocAddress(this.nativeAddress)) + deRegister(phantomRef) + isDisposed = true + } + } + + override def close(): Unit = { + print("close called") + dispose() + } +} + +private[mxnet] class MXHandlePhantomRef(h: MXNativeHandle, val deAllocFn: CPtrAddress => Int) + extends PhantomReference[MXNativeHandle](h, MXHandlePhantomRef.phantomRefQ) { +} + +object MXHandlePhantomRef { + private val phantomRefQ: ReferenceQueue[MXNativeHandle] = new ReferenceQueue[MXNativeHandle] + + private val phantomRefMap = new ConcurrentHashMap[MXHandlePhantomRef, CPtrAddress]() + + def register(referent: MXNativeHandle, deAllocNativeAddr: CPtrAddress => Int): + MXHandlePhantomRef = { + val ref = new MXHandlePhantomRef(referent, deAllocNativeAddr) + phantomRefMap.put(ref, referent.nativeAddress) + ref + } + + def deRegister(phantomRef: MXHandlePhantomRef): Unit = { + val r = phantomRefMap.get(phantomRef) + if (r != null) { + phantomRefMap.remove(phantomRef) + } + } + + def cleanUp(): Unit = { + var ref: MXHandlePhantomRef = phantomRefQ.poll().asInstanceOf[MXHandlePhantomRef] + + while (ref != null) { + val hdl = phantomRefMap.get(ref) + // may be dispose or close was called on this + if (hdl != null) { + ref.deAllocFn(hdl) + phantomRefMap.remove(ref) + } + ref = phantomRefQ.poll().asInstanceOf[MXHandlePhantomRef] + } + } +} \ No newline at end of file From 373ac78b489c76f6a55807157b291a261592506b Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Wed, 29 Aug 2018 23:04:32 -0700 Subject: [PATCH 04/21] add NativeResource and NativeResourceManager with Periodic GC calling --- .../main/scala/org/apache/mxnet/Base.scala | 2 + .../org/apache/mxnet/MXNativeHandle.scala | 104 ----------- .../scala/org/apache/mxnet/MXNetHandle.scala | 101 ----------- .../org/apache/mxnet/NativeResource.scala | 162 ++++++++++++++++++ 4 files changed, 164 insertions(+), 205 deletions(-) delete mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala delete mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index b2a53fd9f2dd..62e0795c7ee3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -78,6 +78,8 @@ private[mxnet] object Base { val _LIB = new LibInfo checkCall(_LIB.nativeLibInit()) + val resourceManager = NativeResourceManager.createPeriodicGCExecutor() + // TODO: shutdown hook won't work on Windows Runtime.getRuntime.addShutdownHook(new Thread() { override def run(): Unit = { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala deleted file mode 100644 index 5aaee7effaad..000000000000 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MXNativeHandle.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet - -import org.apache.mxnet.Base.CPtrAddress - -import java.lang.ref.{PhantomReference, ReferenceQueue} -import java.util.concurrent.ConcurrentHashMap -import org.apache.mxnet.Base.checkCall - -/** - * Should be generic to All MXNet Objects - * Should call the DeAlloc automatically - * Should do it only if not disposed. ie., dispose removes from the refQ - */ - -private[mxnet] trait MXNativeHandle extends AutoCloseable { - - def nativeAddress: CPtrAddress - - def nativeDeAllocAddress: (CPtrAddress => Int) - - val phantomRef: MXHandlePhantomRef - - def bytesAllocated: Long - - var isDisposed: Boolean = false - - def register(referent: MXNativeHandle): MXHandlePhantomRef = { - MXHandlePhantomRef.register(this, nativeDeAllocAddress) - } - - def deRegister(phantomRef: MXHandlePhantomRef): Unit = { - MXHandlePhantomRef.deRegister(phantomRef) - } - - /* call {@link deAllocFn} if !{@link isDispose} */ - def dispose(): Unit = { - print("dispose called") - if (!isDisposed) { - checkCall(nativeDeAllocAddress(this.nativeAddress)) - deRegister(phantomRef) - isDisposed = true - } - } - - override def close(): Unit = { - print("close called") - dispose() - } -} - -private[mxnet] class MXHandlePhantomRef(h: MXNativeHandle, val deAllocFn: CPtrAddress => Int) - extends PhantomReference[MXNativeHandle](h, MXHandlePhantomRef.phantomRefQ) { -} - -object MXHandlePhantomRef { - private val phantomRefQ: ReferenceQueue[MXNativeHandle] = new ReferenceQueue[MXNativeHandle] - - private val phantomRefMap = new ConcurrentHashMap[MXHandlePhantomRef, CPtrAddress]() - - def register(referent: MXNativeHandle, deAllocNativeAddr: CPtrAddress => Int): - MXHandlePhantomRef = { - val ref = new MXHandlePhantomRef(referent, deAllocNativeAddr) - phantomRefMap.put(ref, referent.nativeAddress) - ref - } - - def deRegister(phantomRef: MXHandlePhantomRef): Unit = { - val r = phantomRefMap.get(phantomRef) - if (r != null) { - phantomRefMap.remove(phantomRef) - } - } - - def cleanUp(): Unit = { - var ref: MXHandlePhantomRef = phantomRefQ.poll().asInstanceOf[MXHandlePhantomRef] - - while (ref != null) { - val hdl = phantomRefMap.get(ref) - // may be dispose or close was called on this - if (hdl != null) { - ref.deAllocFn(hdl) - phantomRefMap.remove(ref) - } - ref = phantomRefQ.poll().asInstanceOf[MXHandlePhantomRef] - } - } -} \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala b/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala deleted file mode 100644 index 39046b584234..000000000000 --- a/scala-package/core/src/main/scala/org/apache/mxnet/MXNetHandle.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet - -import org.apache.mxnet.Base.CPtrAddress - -import scala.ref.{PhantomReference, ReferenceQueue} -import java.util.concurrent.ConcurrentHashMap -import org.apache.mxnet.Base.checkCall - -/** - * Should be generic to All MXNet Objects - * Should call the DeAlloc automatically - * Should do it only if not disposed. ie., dispose removes from the refQ - */ - -private[mxnet] trait MXNetHandle extends AutoCloseable { - - val nativeAddress: CPtrAddress - - val bytesAllocated: Long - - var isDisposed: Boolean = false - - val deAllocFn = (mxFreeHandleAddress: CPtrAddress => Int) - - def register(referent: MXNetHandle): Unit = { - MXNetHandlePhantomRef.register(this, deAllocFn) - } - - def deRegister(referent: MXNetHandle): Unit = { - MXNetHandlePhantomRef.deRegister(referent) - } - - /* call {@link deAllocFn} if !{@link isDispose} */ - def dispose(): Unit = { - if (!isDisposed) { - checkCall(checkdeAllocFn) - isDisposed = true - deRegister(this) - } - } - - override def close(): Unit = { - dispose() - } -} - -/** - * Fill me in - * @param h - * @param deAllocFn - */ -private[mxnet] class MXNetHandlePhantomRef(h: MXNetHandle, val deAllocFn: CPtrAddress => Int) - extends PhantomReference[MXNetHandle](h, refQ) { -} - -object MXNetHandlePhantomRef { - private val refQ: ReferenceQueue[MXNetHandle] = new ReferenceQueue[MXNetHandle] - - private val refs = new ConcurrentHashMap[MXNetHandlePhantomRef, CPtrAddress]() - - def register(referent: MXNetHandle, deAllocFn: CPtrAddress => Int): MXNetHandlePhantomRef = { - val ref = new MXNetHandlePhantomRef(referent, deAllocFn) - refs.put(ref, referent.nativeAddress) - ref - } - - def deRegister(referent: MXNetHandlePhantomRef): Unit = { - if ((r = refs.get(referent)) != null) { - refs.remove(referent) - } - } - - def cleanUp(): Unit = { - var ref: MXNetHandlePhantomRef = refQ.poll().asInstanceOf[MXNetHandlePhantomRef] - while (ref != null) { - // may be dispose or close was called on this - if ((hdl = refs.get(ref)) != null) { - ref.deAllocFn(hdl) - refs.remove(ref) - } - ref = refQ.poll().asInstanceOf[MXNetHandlePhantomRef] - } - } -} \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala new file mode 100644 index 000000000000..0bc25b43196b --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.apache.mxnet.Base.CPtrAddress +import java.lang.ref.{PhantomReference, ReferenceQueue} +import java.util.concurrent._ + +import org.apache.mxnet.Base.checkCall +import java.lang.AutoCloseable + +import scala.annotation.varargs +import org.slf4j.{Logger, LoggerFactory} + +import scala.util.Try + +trait NativeResourceManager extends AutoCloseable{ + override def close(): Unit = {} +} + +private[mxnet] object NativeResourceManager { + + // inspired from slide 21 of + def using[T <:NativeResource, U](resource: T)(block: T => U): U = { + try { + block(resource) + } finally { + // TODO(nswamy@): handle exceptions + if (resource != null) resource.close + } + } + + private val logger = LoggerFactory.getLogger(classOf[NativeResourceManager]) + + private val gcFrequencyInSecProp = "mxnet.gcFrequencyInSeconds" + private val gcAfterOffHeapBytesProp = "mxnet.gcAfterOffHeapBytes" + private val maxPhysicalBytesProp = "mxnet.maxPhysicalBytes" + + // ask Jonathan about Singletons + private var _scheduledExecutor: ScheduledExecutorService = null + + // set this to None at the end, so we don't run GC periodically by default + private val defaultGCFrequency = 5 + + private val periodicGCFrequency = Try(System.getProperty( + gcFrequencyInSecProp).toInt).getOrElse(defaultGCFrequency) + + def createPeriodicGCExecutor(): Unit = { + if (periodicGCFrequency != null && _scheduledExecutor == null) { + val scheduledExecutor: ScheduledExecutorService = + Executors.newSingleThreadScheduledExecutor(new ThreadFactory { + override def newThread(r: Runnable): Thread = new Thread(r) { + setName(classOf[NativeResourceManager].getCanonicalName) + setDaemon(true) + } + }) + scheduledExecutor.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { + logger.info("Calling System.gc") + System.gc() + logger.info("Done Calling System.gc") + NativeResourcePhantomRef.cleanUp + logger.info("Done Cleaning up Native Resources") + } + }, + periodicGCFrequency, + periodicGCFrequency, + TimeUnit.SECONDS + ) + _scheduledExecutor = scheduledExecutor + } + } +} + +private[mxnet] trait NativeResource extends AutoCloseable { + + def nativeAddress: CPtrAddress + + def nativeDeAllocAddress: (CPtrAddress => Int) + + val phantomRef: NativeResourcePhantomRef + + def bytesAllocated: Long + + var isDisposed: Boolean = false + + def register(referent: NativeResource): NativeResourcePhantomRef = { + NativeResourcePhantomRef.register(this, nativeDeAllocAddress) + } + + def deRegister(phantomRef: NativeResourcePhantomRef): Unit = { + NativeResourcePhantomRef.deRegister(phantomRef) + } + + /* call {@link deAllocFn} if !{@link isDispose} */ + def dispose(): Unit = { + print("dispose called\n") + if (!isDisposed) { + checkCall(nativeDeAllocAddress(this.nativeAddress)) + deRegister(phantomRef) + isDisposed = true + } + } + + override def close(): Unit = { + print("close called\n") + dispose() + } +} + +private[mxnet] class NativeResourcePhantomRef(h: NativeResource, val deAllocFn: CPtrAddress => Int) + extends PhantomReference[NativeResource](h, NativeResourcePhantomRef.phantomRefQ) { +} + +private[mxnet] object NativeResourcePhantomRef { + private val phantomRefQ: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource] + + private val phantomRefMap = new ConcurrentHashMap[NativeResourcePhantomRef, CPtrAddress]() + + def register(referent: NativeResource, deAllocNativeAddr: CPtrAddress => Int): + NativeResourcePhantomRef = { + val ref = new NativeResourcePhantomRef(referent, deAllocNativeAddr) + phantomRefMap.put(ref, referent.nativeAddress) + ref + } + + def deRegister(phantomRef: NativeResourcePhantomRef): Unit = { + val r = phantomRefMap.get(phantomRef) + if (r != null) { + phantomRefMap.remove(phantomRef) + } + } + + def cleanUp(): Unit = { + var ref: NativeResourcePhantomRef = phantomRefQ.poll().asInstanceOf[NativeResourcePhantomRef] + + while (ref != null) { + val hdl = phantomRefMap.get(ref) + // may be dispose or close was called on this + if (hdl != null) { + ref.deAllocFn(hdl) + phantomRefMap.remove(ref) + } + ref = phantomRefQ.poll().asInstanceOf[NativeResourcePhantomRef] + } + } +} \ No newline at end of file From e0016d7cccb046940f4bd001e809e3aadf0790a3 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Wed, 29 Aug 2018 23:06:05 -0700 Subject: [PATCH 05/21] use NativeResource trait in NDArray, Symbol and Executor --- .../scala/org/apache/mxnet/Executor.scala | 17 +++++++++-- .../main/scala/org/apache/mxnet/NDArray.scala | 29 ++++++++++++------- .../org/apache/mxnet/NativeResource.scala | 2 -- .../main/scala/org/apache/mxnet/Symbol.scala | 16 ++++++++-- 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index fc791d5cd9a3..a83ad2c77763 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -45,7 +45,17 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, - private[mxnet] val symbol: Symbol) extends WarnIfNotDisposed { + private[mxnet] val symbol: Symbol) + extends WarnIfNotDisposed with NativeResource { + + override def nativeAddress: CPtrAddress = handle + + override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxExecutorFree + + override val phantomRef: NativeResourcePhantomRef = super.register(this) + + override def bytesAllocated: Long = 0 + private[mxnet] var argArrays: Array[NDArray] = null private[mxnet] var gradArrays: Array[NDArray] = null private[mxnet] var auxArrays: Array[NDArray] = null @@ -60,8 +70,8 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) private var disposed = false - protected def isDisposed = disposed - +// protected def isDisposed = disposed +/* def dispose(): Unit = { if (!disposed) { outputs.foreach(_.dispose()) @@ -69,6 +79,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, disposed = true } } +*/ /** * Return a new executor with the same symbol and shared memory, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 9b6a7dc66540..269fe7a96eca 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -562,16 +562,24 @@ object NDArray extends NDArrayBase { */ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true, - addToCollector: Boolean = true) extends WarnIfNotDisposed { + addToCollector: Boolean = true) + extends WarnIfNotDisposed with NativeResource { if (addToCollector) { NDArrayCollector.collect(this) } + override def nativeAddress: CPtrAddress = handle + + override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxNDArrayFree + + override val phantomRef: NativeResourcePhantomRef = super.register(this) + + override def bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product + // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] - @volatile private var disposed = false - def isDisposed: Boolean = disposed + @volatile private var disposed = isDisposed def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] @@ -584,13 +592,13 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * The NDArrays it depends on will NOT be disposed.
* The object shall never be used after it is disposed. */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxNDArrayFree(handle) - dependencies.clear() - disposed = true - } - } +// def dispose(): Unit = { +// if (!disposed) { +// _LIB.mxNDArrayFree(handle) +// dependencies.clear() +// disposed = true +// } +// } /** * Dispose all NDArrays who help to construct this array.
@@ -1034,6 +1042,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, // TODO: naive implementation shape.hashCode + toArray.hashCode } + } private[mxnet] object NDArrayConversions { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 0bc25b43196b..96ffc3bb5db6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -109,7 +109,6 @@ private[mxnet] trait NativeResource extends AutoCloseable { /* call {@link deAllocFn} if !{@link isDispose} */ def dispose(): Unit = { - print("dispose called\n") if (!isDisposed) { checkCall(nativeDeAllocAddress(this.nativeAddress)) deRegister(phantomRef) @@ -118,7 +117,6 @@ private[mxnet] trait NativeResource extends AutoCloseable { } override def close(): Unit = { - print("close called\n") dispose() } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index b1a3e392f41e..8519a458f570 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -29,21 +29,33 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * WARNING: it is your responsibility to clear this object through dispose(). * */ -class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotDisposed { +class Symbol private(private[mxnet] val handle: SymbolHandle) + extends WarnIfNotDisposed with NativeResource { + + override def nativeAddress: CPtrAddress = handle + + override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxSymbolFree + + override val phantomRef: NativeResourcePhantomRef = super.register(this) + + override def bytesAllocated: Long = 0 + private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) private var disposed = false - protected def isDisposed = disposed +// protected def isDisposed = disposed /** * Release the native memory. * The object shall never be used after it is disposed. */ +/* def dispose(): Unit = { if (!disposed) { _LIB.mxSymbolFree(handle) disposed = true } } +*/ def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) def +[@specialized(Int, Float, Double) V](other: V): Symbol = { From 5cd3cd354ba0f0ee401ae7996060e2a245082afb Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Wed, 29 Aug 2018 23:20:39 -0700 Subject: [PATCH 06/21] add run train mnist script --- .../examples/scripts/run_train_mnist.sh | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100755 scala-package/examples/scripts/run_train_mnist.sh diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh new file mode 100755 index 000000000000..ea53c1ade66f --- /dev/null +++ b/scala-package/examples/scripts/run_train_mnist.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +MXNET_ROOT=$(cd "$(dirname $0)/../../.."; pwd) +echo $MXNET_ROOT +CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/* + +# model dir +DATA_PATH=$2 + +java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp $CLASS_PATH \ + org.apache.mxnetexamples.imclassification.TrainMnist \ + --data-dir /home/ubuntu/mxnet_scala/scala-package/examples/mnist/ \ + --num-epochs 10000000 \ + --batch-size 1024 \ No newline at end of file From ef4bfe839c6457987d61871e777cf2ae1ba62e72 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Tue, 4 Sep 2018 11:07:15 -0700 Subject: [PATCH 07/21] create a Generic ResourceScope that can collect all NativeResources to dispose at the end --- .../main/scala/org/apache/mxnet/Base.scala | 2 +- .../scala/org/apache/mxnet/Executor.scala | 2 +- .../main/scala/org/apache/mxnet/NDArray.scala | 12 +- .../org/apache/mxnet/NativeResource.scala | 206 +++++++++++++----- .../main/scala/org/apache/mxnet/Symbol.scala | 2 +- 5 files changed, 163 insertions(+), 61 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index 62e0795c7ee3..0cdb492bb8d2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -78,7 +78,7 @@ private[mxnet] object Base { val _LIB = new LibInfo checkCall(_LIB.nativeLibInit()) - val resourceManager = NativeResourceManager.createPeriodicGCExecutor() + val resourceManager = PeriodicGCDeAllocator.createPeriodicGCExecutor() // TODO: shutdown hook won't work on Windows Runtime.getRuntime.addShutdownHook(new Thread() { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index a83ad2c77763..825cc096759e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -52,7 +52,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxExecutorFree - override val phantomRef: NativeResourcePhantomRef = super.register(this) + override val phantomRef: NativeResourceRef = super.register(this) override def bytesAllocated: Long = 0 diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 269fe7a96eca..0e3ab66e00be 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -563,22 +563,28 @@ object NDArray extends NDArrayBase { class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true, addToCollector: Boolean = true) - extends WarnIfNotDisposed with NativeResource { + extends WarnIfNotDisposed // { + with NativeResource { if (addToCollector) { NDArrayCollector.collect(this) } + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxNDArrayFree - override val phantomRef: NativeResourcePhantomRef = super.register(this) + override val phantomRef: NativeResourceRef = super.register(this) override def bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product + // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] + +// @volatile private var disposed = false +// def isDisposed: Boolean = disposed @volatile private var disposed = isDisposed def serialize(): Array[Byte] = { @@ -593,7 +599,9 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * The object shall never be used after it is disposed. */ // def dispose(): Unit = { +// print("dispose\n") // if (!disposed) { +// print("disposing\n") // _LIB.mxNDArrayFree(handle) // dependencies.clear() // disposed = true diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 96ffc3bb5db6..4ebc94af9aa0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -18,40 +18,28 @@ package org.apache.mxnet import org.apache.mxnet.Base.CPtrAddress -import java.lang.ref.{PhantomReference, ReferenceQueue} +import java.lang.ref.{WeakReference, PhantomReference, ReferenceQueue} import java.util.concurrent._ import org.apache.mxnet.Base.checkCall -import java.lang.AutoCloseable +import java.lang.{AutoCloseable, ThreadLocal} -import scala.annotation.varargs import org.slf4j.{Logger, LoggerFactory} +import scala.collection.mutable.{ArrayBuffer, ArrayStack} import scala.util.Try -trait NativeResourceManager extends AutoCloseable{ - override def close(): Unit = {} -} +private[mxnet] class PeriodicGCDeAllocator { -private[mxnet] object NativeResourceManager { +} - // inspired from slide 21 of - def using[T <:NativeResource, U](resource: T)(block: T => U): U = { - try { - block(resource) - } finally { - // TODO(nswamy@): handle exceptions - if (resource != null) resource.close - } - } +private[mxnet] object PeriodicGCDeAllocator { - private val logger = LoggerFactory.getLogger(classOf[NativeResourceManager]) + private val logger = LoggerFactory.getLogger(classOf[PeriodicGCDeAllocator]) private val gcFrequencyInSecProp = "mxnet.gcFrequencyInSeconds" private val gcAfterOffHeapBytesProp = "mxnet.gcAfterOffHeapBytes" private val maxPhysicalBytesProp = "mxnet.maxPhysicalBytes" - - // ask Jonathan about Singletons private var _scheduledExecutor: ScheduledExecutorService = null // set this to None at the end, so we don't run GC periodically by default @@ -65,7 +53,7 @@ private[mxnet] object NativeResourceManager { val scheduledExecutor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor(new ThreadFactory { override def newThread(r: Runnable): Thread = new Thread(r) { - setName(classOf[NativeResourceManager].getCanonicalName) + setName(classOf[ResourceScope].getCanonicalName) setDaemon(true) } }) @@ -74,8 +62,6 @@ private[mxnet] object NativeResourceManager { logger.info("Calling System.gc") System.gc() logger.info("Done Calling System.gc") - NativeResourcePhantomRef.cleanUp - logger.info("Done Cleaning up Native Resources") } }, periodicGCFrequency, @@ -87,74 +73,182 @@ private[mxnet] object NativeResourceManager { } } +class ResourceScope extends AutoCloseable { + import ResourceScope.{logger, resourceScope} + + private val resourceQ = new ArrayBuffer[NativeResource]() + resourceScope.get().+=(this) + + override def close(): Unit = { + resourceQ.foreach(resource => if (resource != null) { + logger.info("releasing resource:%x\n".format(resource.nativeAddress)) + resource.dispose() + resource.deRegister(false) + } else {logger.info("found resource which is null")} + ) + ResourceScope.resourceScope.get().-=(this) + } + + private[mxnet] def register(resource: NativeResource): Unit = { + logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) + resourceQ.+=(resource) + } + + // TODO(@nswamy): this is linear in time, find better data structure + private[mxnet] def deRegister(resource: NativeResource): Unit = { + logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) + resourceQ.-=(resource) + } +} + + object ResourceScope { + + private val logger = LoggerFactory.getLogger(classOf[ResourceScope]) + + // inspired from slide 21 of + def using[T](resource: ResourceScope)(block: => T): T = { + require(resource != null) + try { + val ret = block + ret match { + case nRes: NativeResource => + resource.deRegister(nRes.asInstanceOf[NativeResource]) + case _ => // do nothing + } + ret + } finally { + // TODO(nswamy@): handle exceptions + resource.close + } + } + + private[mxnet] val resourceScope = new ThreadLocal[ArrayBuffer[ResourceScope]] { + override def initialValue(): ArrayBuffer[ResourceScope] = + new ArrayBuffer[ResourceScope]() + } + + private[mxnet] def getScope(): ResourceScope = { + try { + resourceScope.get().last + } catch { + case _: ArrayIndexOutOfBoundsException => null + case _: NoSuchElementException => null + case e: Exception => throw e + } + } +} + private[mxnet] trait NativeResource extends AutoCloseable { def nativeAddress: CPtrAddress def nativeDeAllocAddress: (CPtrAddress => Int) - val phantomRef: NativeResourcePhantomRef + /** Call {@link NativeResource.register} to get NativeResourcePhantomRef + * + */ + val phantomRef: NativeResourceRef def bytesAllocated: Long var isDisposed: Boolean = false - def register(referent: NativeResource): NativeResourcePhantomRef = { - NativeResourcePhantomRef.register(this, nativeDeAllocAddress) + private var scope: ResourceScope = null + + def register(referent: NativeResource): NativeResourceRef = { + scope = ResourceScope.getScope() + if (scope != null) { + scope.register(this) + } + // register with PhantomRef tracking to release incase the objects go + // out of reference within scope but are held for long time + NativeResourceRef.register(this, nativeDeAllocAddress) } - def deRegister(phantomRef: NativeResourcePhantomRef): Unit = { - NativeResourcePhantomRef.deRegister(phantomRef) + /** + * remove from PhantomRef tracking and + * ResourceScope tracking + */ + def deRegister(removeFromScope: Boolean = true): Unit = { + NativeResourceRef.deRegister(phantomRef) + if (scope != null && removeFromScope) scope.deRegister(this) + } + + override def close(): Unit = { + dispose() + deRegister(true) } /* call {@link deAllocFn} if !{@link isDispose} */ - def dispose(): Unit = { + final def dispose(): Unit = { if (!isDisposed) { + print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocAddress(this.nativeAddress)) - deRegister(phantomRef) isDisposed = true } } - - override def close(): Unit = { - dispose() - } } -private[mxnet] class NativeResourcePhantomRef(h: NativeResource, val deAllocFn: CPtrAddress => Int) - extends PhantomReference[NativeResource](h, NativeResourcePhantomRef.phantomRefQ) { +// do not make nativeRes a member, this will hold reference and GC will not clear the object. +private[mxnet] class NativeResourceRef(resource: NativeResource, + val resDeAllocAddr: CPtrAddress => Int) + extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) { } -private[mxnet] object NativeResourcePhantomRef { - private val phantomRefQ: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource] +private[mxnet] object NativeResourceRef { + + private val referenceQueue: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource] + + private val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() + + private val cleanupThread = new ResourceCleanupThread() - private val phantomRefMap = new ConcurrentHashMap[NativeResourcePhantomRef, CPtrAddress]() + cleanupThread.start() - def register(referent: NativeResource, deAllocNativeAddr: CPtrAddress => Int): - NativeResourcePhantomRef = { - val ref = new NativeResourcePhantomRef(referent, deAllocNativeAddr) - phantomRefMap.put(ref, referent.nativeAddress) - ref + def register(resource: NativeResource, resDeAllocAddr: CPtrAddress => Int): + NativeResourceRef = { + val resourceRef = new NativeResourceRef(resource, resDeAllocAddr) + phantomRefMap.put(resourceRef, resource.nativeAddress) + resourceRef } - def deRegister(phantomRef: NativeResourcePhantomRef): Unit = { - val r = phantomRefMap.get(phantomRef) - if (r != null) { - phantomRefMap.remove(phantomRef) + def deRegister(resourceRef: NativeResourceRef): Unit = { + val resDeAllocAddr = phantomRefMap.get(resourceRef) + if (resDeAllocAddr != null) { + phantomRefMap.remove(resourceRef) } } - def cleanUp(): Unit = { - var ref: NativeResourcePhantomRef = phantomRefQ.poll().asInstanceOf[NativeResourcePhantomRef] + def cleanup(): Unit = { + print("NativeResourceRef: cleanup\n") + // remove is a blocking call + val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef] + print("NativeResourceRef: got a reference with deAlloc\n") + // phantomRef will be removed from the map when NativeResource.close is called. + val resource = phantomRefMap.get(ref) + + if (resource != null) { + print("NativeResourceRef: got a reference for resource\n") + ref.resDeAllocAddr(resource) + phantomRefMap.remove(ref) + } + } + + private class ResourceCleanupThread extends Thread { + setPriority(Thread.MAX_PRIORITY) + setName("NativeResourceDeAllocatorThread") + setDaemon(true) - while (ref != null) { - val hdl = phantomRefMap.get(ref) - // may be dispose or close was called on this - if (hdl != null) { - ref.deAllocFn(hdl) - phantomRefMap.remove(ref) + override def run(): Unit = { + while (true) { + try { + cleanup() + } + catch { + case _: InterruptedException => Thread.currentThread().interrupt() + } } - ref = phantomRefQ.poll().asInstanceOf[NativeResourcePhantomRef] } } + } \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 8519a458f570..ccf4a833aa96 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -36,7 +36,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxSymbolFree - override val phantomRef: NativeResourcePhantomRef = super.register(this) + override val phantomRef: NativeResourceRef = super.register(this) override def bytesAllocated: Long = 0 From c04e4f0f96d4cbcb0eb937f1a386a013ebe10671 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 7 Sep 2018 00:21:31 -0700 Subject: [PATCH 08/21] modify NativeResource and ResourceScope, extend NativeResource in NDArray, Symbol and Executor --- .../scala/org/apache/mxnet/Executor.scala | 29 +-- .../main/scala/org/apache/mxnet/NDArray.scala | 31 +-- .../org/apache/mxnet/NativeResource.scala | 194 +++++------------- .../org/apache/mxnet/ResourceScope.scala | 86 ++++++++ .../main/scala/org/apache/mxnet/Symbol.scala | 32 +-- 5 files changed, 167 insertions(+), 205 deletions(-) create mode 100644 scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 825cc096759e..def97327a2e9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -45,17 +45,7 @@ object Executor { * @see Symbol.bind : to create executor */ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, - private[mxnet] val symbol: Symbol) - extends WarnIfNotDisposed with NativeResource { - - override def nativeAddress: CPtrAddress = handle - - override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxExecutorFree - - override val phantomRef: NativeResourceRef = super.register(this) - - override def bytesAllocated: Long = 0 - + private[mxnet] val symbol: Symbol) extends NativeResource { private[mxnet] var argArrays: Array[NDArray] = null private[mxnet] var gradArrays: Array[NDArray] = null private[mxnet] var auxArrays: Array[NDArray] = null @@ -69,17 +59,11 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) - private var disposed = false -// protected def isDisposed = disposed -/* - def dispose(): Unit = { - if (!disposed) { - outputs.foreach(_.dispose()) - _LIB.mxExecutorFree(handle) - disposed = true - } - } -*/ + override def nativeAddress: CPtrAddress = handle + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree + // cannot determine the off-heap size of this object + override def bytesAllocated: Long = 0 + override val phantomRef: NativeResourceRef = super.register() /** * Return a new executor with the same symbol and shared memory, @@ -316,4 +300,5 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, checkCall(_LIB.mxExecutorPrint(handle, str)) str.value } + } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 0e3ab66e00be..77b1f60aba3a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -562,31 +562,21 @@ object NDArray extends NDArrayBase { */ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, val writable: Boolean = true, - addToCollector: Boolean = true) - extends WarnIfNotDisposed // { - with NativeResource { + addToCollector: Boolean = true) extends NativeResource { if (addToCollector) { NDArrayCollector.collect(this) } - override def nativeAddress: CPtrAddress = handle - - override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxNDArrayFree - - override val phantomRef: NativeResourceRef = super.register(this) - + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree override def bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product + override val phantomRef: NativeResourceRef = super.register() // record arrays who construct this array instance // we use weak reference to prevent gc blocking private[mxnet] val dependencies = mutable.HashMap.empty[Long, WeakReference[NDArray]] -// @volatile private var disposed = false -// def isDisposed: Boolean = disposed - @volatile private var disposed = isDisposed - def serialize(): Array[Byte] = { val buf = ArrayBuffer.empty[Byte] checkCall(_LIB.mxNDArraySaveRawBytes(handle, buf)) @@ -598,15 +588,12 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * The NDArrays it depends on will NOT be disposed.
* The object shall never be used after it is disposed. */ -// def dispose(): Unit = { -// print("dispose\n") -// if (!disposed) { -// print("disposing\n") -// _LIB.mxNDArrayFree(handle) -// dependencies.clear() -// disposed = true -// } -// } + override def dispose(): Unit = { + if (!disposed) { + super.dispose() + dependencies.clear() + } + } /** * Dispose all NDArrays who help to construct this array.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 4ebc94af9aa0..edc6901dcc4e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -18,180 +18,101 @@ package org.apache.mxnet import org.apache.mxnet.Base.CPtrAddress -import java.lang.ref.{WeakReference, PhantomReference, ReferenceQueue} +import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference} import java.util.concurrent._ import org.apache.mxnet.Base.checkCall import java.lang.{AutoCloseable, ThreadLocal} - +import java.util.concurrent.atomic.AtomicLong import org.slf4j.{Logger, LoggerFactory} -import scala.collection.mutable.{ArrayBuffer, ArrayStack} -import scala.util.Try - -private[mxnet] class PeriodicGCDeAllocator { - -} - -private[mxnet] object PeriodicGCDeAllocator { - - private val logger = LoggerFactory.getLogger(classOf[PeriodicGCDeAllocator]) - - private val gcFrequencyInSecProp = "mxnet.gcFrequencyInSeconds" - private val gcAfterOffHeapBytesProp = "mxnet.gcAfterOffHeapBytes" - private val maxPhysicalBytesProp = "mxnet.maxPhysicalBytes" - private var _scheduledExecutor: ScheduledExecutorService = null - - // set this to None at the end, so we don't run GC periodically by default - private val defaultGCFrequency = 5 - - private val periodicGCFrequency = Try(System.getProperty( - gcFrequencyInSecProp).toInt).getOrElse(defaultGCFrequency) - - def createPeriodicGCExecutor(): Unit = { - if (periodicGCFrequency != null && _scheduledExecutor == null) { - val scheduledExecutor: ScheduledExecutorService = - Executors.newSingleThreadScheduledExecutor(new ThreadFactory { - override def newThread(r: Runnable): Thread = new Thread(r) { - setName(classOf[ResourceScope].getCanonicalName) - setDaemon(true) - } - }) - scheduledExecutor.scheduleAtFixedRate(new Runnable { - override def run(): Unit = { - logger.info("Calling System.gc") - System.gc() - logger.info("Done Calling System.gc") - } - }, - periodicGCFrequency, - periodicGCFrequency, - TimeUnit.SECONDS - ) - _scheduledExecutor = scheduledExecutor - } - } -} - -class ResourceScope extends AutoCloseable { - import ResourceScope.{logger, resourceScope} - - private val resourceQ = new ArrayBuffer[NativeResource]() - resourceScope.get().+=(this) - - override def close(): Unit = { - resourceQ.foreach(resource => if (resource != null) { - logger.info("releasing resource:%x\n".format(resource.nativeAddress)) - resource.dispose() - resource.deRegister(false) - } else {logger.info("found resource which is null")} - ) - ResourceScope.resourceScope.get().-=(this) - } - - private[mxnet] def register(resource: NativeResource): Unit = { - logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) - resourceQ.+=(resource) - } - - // TODO(@nswamy): this is linear in time, find better data structure - private[mxnet] def deRegister(resource: NativeResource): Unit = { - logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) - resourceQ.-=(resource) - } -} - - object ResourceScope { - - private val logger = LoggerFactory.getLogger(classOf[ResourceScope]) - - // inspired from slide 21 of - def using[T](resource: ResourceScope)(block: => T): T = { - require(resource != null) - try { - val ret = block - ret match { - case nRes: NativeResource => - resource.deRegister(nRes.asInstanceOf[NativeResource]) - case _ => // do nothing - } - ret - } finally { - // TODO(nswamy@): handle exceptions - resource.close - } - } - - private[mxnet] val resourceScope = new ThreadLocal[ArrayBuffer[ResourceScope]] { - override def initialValue(): ArrayBuffer[ResourceScope] = - new ArrayBuffer[ResourceScope]() - } - - private[mxnet] def getScope(): ResourceScope = { - try { - resourceScope.get().last - } catch { - case _: ArrayIndexOutOfBoundsException => null - case _: NoSuchElementException => null - case e: Exception => throw e - } - } -} - -private[mxnet] trait NativeResource extends AutoCloseable { +/** + * NativeResource trait is used to manage MXNet Objects + * such as NDArray, Symbol, Executor, etc., + * The MXNet Object calls {@link NativeResource.register} + * and assign the returned NativeResourceRef to {@link phantomRef} + * NativeResource also implements AutoCloseable so MXNetObjects + * can be used like Resources in try-with-resources paradigm + */ +private[mxnet] trait NativeResource + extends AutoCloseable with WarnIfNotDisposed { + /** + * native Address associated with this object + * @return + */ def nativeAddress: CPtrAddress - def nativeDeAllocAddress: (CPtrAddress => Int) + /** + * Function Pointer to the NativeDeAllocator of {@link nativeAddress} + * @return + */ + def nativeDeAllocator: (CPtrAddress => Int) - /** Call {@link NativeResource.register} to get NativeResourcePhantomRef - * - */ + /** Call {@link NativeResource.register} to get {@link NativeResourceRef} + */ val phantomRef: NativeResourceRef + /** + * Off-Heap Bytes Allocated for this object + * @return + */ def bytesAllocated: Long - var isDisposed: Boolean = false - private var scope: ResourceScope = null - def register(referent: NativeResource): NativeResourceRef = { + @volatile var disposed = false + + override def isDisposed: Boolean = disposed + /** + * Register this object for PhantomReference tracking and within + * ResourceScope if used inside ResourceScope. + * @return NativeResourceRef that tracks reachability of this object + * using PhantomReference + */ + def register(): NativeResourceRef = { + scope = ResourceScope.getScope() if (scope != null) { scope.register(this) } // register with PhantomRef tracking to release incase the objects go // out of reference within scope but are held for long time - NativeResourceRef.register(this, nativeDeAllocAddress) + NativeResourceRef.register(this, nativeDeAllocator) } /** - * remove from PhantomRef tracking and - * ResourceScope tracking + * Removes this object from PhantomRef tracking and from ResourceScope + * @param removeFromScope */ def deRegister(removeFromScope: Boolean = true): Unit = { NativeResourceRef.deRegister(phantomRef) if (scope != null && removeFromScope) scope.deRegister(this) } + /** + * Implements {@link AutoCloseable.close} + */ override def close(): Unit = { dispose() deRegister(true) } - /* call {@link deAllocFn} if !{@link isDispose} */ - final def dispose(): Unit = { - if (!isDisposed) { + /** + * Implements {@link WarnIfNotDisposed.dispose} + */ + def dispose(): Unit = { + if (!disposed) { print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) - checkCall(nativeDeAllocAddress(this.nativeAddress)) - isDisposed = true + checkCall(nativeDeAllocator(this.nativeAddress)) + disposed = true } } } // do not make nativeRes a member, this will hold reference and GC will not clear the object. private[mxnet] class NativeResourceRef(resource: NativeResource, - val resDeAllocAddr: CPtrAddress => Int) + val resourceDeAllocator: CPtrAddress => Int) extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) { } @@ -205,16 +126,16 @@ private[mxnet] object NativeResourceRef { cleanupThread.start() - def register(resource: NativeResource, resDeAllocAddr: CPtrAddress => Int): + def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)): NativeResourceRef = { - val resourceRef = new NativeResourceRef(resource, resDeAllocAddr) + val resourceRef = new NativeResourceRef(resource, nativeDeAllocator) phantomRefMap.put(resourceRef, resource.nativeAddress) resourceRef } def deRegister(resourceRef: NativeResourceRef): Unit = { - val resDeAllocAddr = phantomRefMap.get(resourceRef) - if (resDeAllocAddr != null) { + val nativeDeAllocator = phantomRefMap.get(resourceRef) + if (nativeDeAllocator != null) { phantomRefMap.remove(resourceRef) } } @@ -229,7 +150,7 @@ private[mxnet] object NativeResourceRef { if (resource != null) { print("NativeResourceRef: got a reference for resource\n") - ref.resDeAllocAddr(resource) + ref.resourceDeAllocator(resource) phantomRefMap.remove(ref) } } @@ -250,5 +171,4 @@ private[mxnet] object NativeResourceRef { } } } - } \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala new file mode 100644 index 000000000000..eb598405adc2 --- /dev/null +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import org.slf4j.LoggerFactory +import scala.collection.mutable.ArrayBuffer + +class ResourceScope extends AutoCloseable { + import ResourceScope.{logger, resourceScope} + + private val resourceQ = new ArrayBuffer[NativeResource]() + resourceScope.get().+=(this) + + override def close(): Unit = { + resourceQ.foreach(resource => if (resource != null) { + logger.info("releasing resource:%x\n".format(resource.nativeAddress)) + resource.dispose() + resource.deRegister(false) + } else {logger.info("found resource which is null")} + ) + ResourceScope.resourceScope.get().-=(this) + } + + private[mxnet] def register(resource: NativeResource): Unit = { + logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) + resourceQ.+=(resource) + } + + // TODO(@nswamy): this is linear in time, find better data structure + private[mxnet] def deRegister(resource: NativeResource): Unit = { + logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) + resourceQ.-=(resource) + } +} + +object ResourceScope { + + private val logger = LoggerFactory.getLogger(classOf[ResourceScope]) + + // inspired from slide 21 of + def using[T](resource: ResourceScope)(block: => T): T = { + require(resource != null) + try { + val ret = block + ret match { + case nRes: NativeResource => + resource.deRegister(nRes.asInstanceOf[NativeResource]) + case _ => // do nothing + } + ret + } finally { + // TODO(nswamy@): handle exceptions + resource.close + } + } + + private[mxnet] val resourceScope = new ThreadLocal[ArrayBuffer[ResourceScope]] { + override def initialValue(): ArrayBuffer[ResourceScope] = + new ArrayBuffer[ResourceScope]() + } + + private[mxnet] def getScope(): ResourceScope = { + try { + resourceScope.get().last + } catch { + case _: ArrayIndexOutOfBoundsException => null + case _: NoSuchElementException => null + case e: Exception => throw e + } + } +} \ No newline at end of file diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index ccf4a833aa96..7dd86720b6b0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -29,33 +29,16 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * WARNING: it is your responsibility to clear this object through dispose(). * */ -class Symbol private(private[mxnet] val handle: SymbolHandle) - extends WarnIfNotDisposed with NativeResource { - - override def nativeAddress: CPtrAddress = handle - - override def nativeDeAllocAddress: CPtrAddress => Int = _LIB.mxSymbolFree +class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeResource { + private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) - override val phantomRef: NativeResourceRef = super.register(this) + // unable to get the byteAllocated for Symbol + override def bytesAllocated: Long = 0L - override def bytesAllocated: Long = 0 + override def nativeAddress: CPtrAddress = handle + override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree + override val phantomRef: NativeResourceRef = super.register() - private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) - private var disposed = false -// protected def isDisposed = disposed - - /** - * Release the native memory. - * The object shall never be used after it is disposed. - */ -/* - def dispose(): Unit = { - if (!disposed) { - _LIB.mxSymbolFree(handle) - disposed = true - } - } -*/ def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) def +[@specialized(Int, Float, Double) V](other: V): Symbol = { @@ -844,6 +827,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) checkCall(_LIB.mxSymbolSaveToJSON(handle, jsonStr)) jsonStr.value } + } /** From bb369342d1e8bda9360fcd6f668595abca047c97 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 7 Sep 2018 00:21:54 -0700 Subject: [PATCH 09/21] remove GCExecutor --- scala-package/core/src/main/scala/org/apache/mxnet/Base.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala index 0cdb492bb8d2..b2a53fd9f2dd 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Base.scala @@ -78,8 +78,6 @@ private[mxnet] object Base { val _LIB = new LibInfo checkCall(_LIB.nativeLibInit()) - val resourceManager = PeriodicGCDeAllocator.createPeriodicGCExecutor() - // TODO: shutdown hook won't work on Windows Runtime.getRuntime.addShutdownHook(new Thread() { override def run(): Unit = { From 9d92dc12d5acb72b5dca20627ffdb550fbf3b782 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 7 Sep 2018 09:06:15 -0700 Subject: [PATCH 10/21] deRegister PhantomReferences by when calling dispose() --- .../scala/org/apache/mxnet/NativeResource.scala | 16 +++++++++++++--- .../scala/org/apache/mxnet/ResourceScope.scala | 3 +-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index edc6901dcc4e..55a2168244ff 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -85,7 +85,7 @@ private[mxnet] trait NativeResource * Removes this object from PhantomRef tracking and from ResourceScope * @param removeFromScope */ - def deRegister(removeFromScope: Boolean = true): Unit = { + private def deRegister(removeFromScope: Boolean = true): Unit = { NativeResourceRef.deRegister(phantomRef) if (scope != null && removeFromScope) scope.deRegister(this) } @@ -106,6 +106,16 @@ private[mxnet] trait NativeResource print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocator(this.nativeAddress)) disposed = true + deRegister(true) + } + } + + def dispose(removeFromScope: Boolean): Unit = { + if (!disposed) { + print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) + checkCall(nativeDeAllocator(this.nativeAddress)) + disposed = true + deRegister(removeFromScope) } } } @@ -135,7 +145,7 @@ private[mxnet] object NativeResourceRef { def deRegister(resourceRef: NativeResourceRef): Unit = { val nativeDeAllocator = phantomRefMap.get(resourceRef) - if (nativeDeAllocator != null) { + if (nativeDeAllocator != 0L) { // since CPtrAddress is Scala Long, it cannot be null phantomRefMap.remove(resourceRef) } } @@ -148,7 +158,7 @@ private[mxnet] object NativeResourceRef { // phantomRef will be removed from the map when NativeResource.close is called. val resource = phantomRefMap.get(ref) - if (resource != null) { + if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null print("NativeResourceRef: got a reference for resource\n") ref.resourceDeAllocator(resource) phantomRefMap.remove(ref) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index eb598405adc2..873aa56082d6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -29,8 +29,7 @@ class ResourceScope extends AutoCloseable { override def close(): Unit = { resourceQ.foreach(resource => if (resource != null) { logger.info("releasing resource:%x\n".format(resource.nativeAddress)) - resource.dispose() - resource.deRegister(false) + resource.dispose(false) } else {logger.info("found resource which is null")} ) ResourceScope.resourceScope.get().-=(this) From 14657171c6011da9becedddf86ba2d4fff09bb38 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 7 Sep 2018 09:37:53 -0700 Subject: [PATCH 11/21] add Finalizer(temporary) to NativeResource --- .../src/main/scala/org/apache/mxnet/NativeResource.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 55a2168244ff..6b34e072721b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -34,6 +34,7 @@ import org.slf4j.{Logger, LoggerFactory} * NativeResource also implements AutoCloseable so MXNetObjects * can be used like Resources in try-with-resources paradigm */ +// scalastyle:off finalize private[mxnet] trait NativeResource extends AutoCloseable with WarnIfNotDisposed { @@ -118,7 +119,15 @@ private[mxnet] trait NativeResource deRegister(removeFromScope) } } + + override protected def finalize(): Unit = { + if (!isDisposed) { + print("LEAK: %x\n".format(this.nativeAddress)) + super.finalize() + } + } } +// scalastyle:on finalize // do not make nativeRes a member, this will hold reference and GC will not clear the object. private[mxnet] class NativeResourceRef(resource: NativeResource, From e9b4b705835e239a7026dbcf8b280f91ae98e484 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 7 Sep 2018 15:42:26 -0700 Subject: [PATCH 12/21] refactor NativeResource.dispose() method --- .../src/main/scala/org/apache/mxnet/NativeResource.scala | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 6b34e072721b..8a06a96c9a8e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -103,20 +103,15 @@ private[mxnet] trait NativeResource * Implements {@link WarnIfNotDisposed.dispose} */ def dispose(): Unit = { - if (!disposed) { - print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) - checkCall(nativeDeAllocator(this.nativeAddress)) - disposed = true - deRegister(true) - } + dispose(true) } def dispose(removeFromScope: Boolean): Unit = { if (!disposed) { print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocator(this.nativeAddress)) - disposed = true deRegister(removeFromScope) + disposed = true } } From dd294f06f8f058a4002ac9bd246d354f5d13708d Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 21 Sep 2018 17:46:45 +0100 Subject: [PATCH 13/21] update NativeResource/add Unit Test for NativeResource --- scala-package/core/pom.xml | 7 ++ .../scala/org/apache/mxnet/Executor.scala | 4 +- .../src/main/scala/org/apache/mxnet/IO.scala | 15 ++- .../main/scala/org/apache/mxnet/NDArray.scala | 4 +- .../org/apache/mxnet/NativeResource.scala | 96 ++++++++----------- .../org/apache/mxnet/ResourceScope.scala | 7 +- .../main/scala/org/apache/mxnet/Symbol.scala | 4 +- .../org/apache/mxnet/io/MXDataIter.scala | 19 ++-- .../org/apache/mxnet/DataBatchSuite.scala | 63 ++++++++++++ .../apache/mxnet/NativeResourceSuite.scala | 69 +++++++++++++ 10 files changed, 208 insertions(+), 80 deletions(-) create mode 100644 scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala create mode 100644 scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index ea3a2d68c9f4..e93169f08faa 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -123,5 +123,12 @@ commons-io 2.1 + + + org.mockito + mockito-all + 1.10.19 + test + diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index def97327a2e9..5b7941d8557f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -59,10 +59,10 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) - override def nativeAddress: CPtrAddress = handle + override def nativeResource: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree // cannot determine the off-heap size of this object - override def bytesAllocated: Long = 0 + override val bytesAllocated: Long = 0 override val phantomRef: NativeResourceRef = super.register() /** diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index e8351422c488..36761fae8e0f 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -144,7 +144,8 @@ class DataBatch(val data: IndexedSeq[NDArray], // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) private val providedDataDesc: IndexedSeq[DataDesc], - private val providedLabelDesc: IndexedSeq[DataDesc]) { + private val providedLabelDesc: IndexedSeq[DataDesc]) extends + NativeResource { // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] // However, since the data and label can be accessed publicly (no getter and setter) // the change on this will break BC @@ -162,17 +163,26 @@ class DataBatch(val data: IndexedSeq[NDArray], this(data, label, index, pad, bucketKey, DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) } + + // overriding here so DataBatch gets added to Scope and can be disposed + override def nativeResource: CPtrAddress = 0 + override def nativeDeAllocator: CPtrAddress => MXUint = doNothingDeAllocator + def doNothingDeAllocator(x: CPtrAddress): MXUint = {0} + override val phantomRef: NativeResourceRef = super.register() + override val bytesAllocated: DataIterCreator = 0 + /** * Dispose its data and labels * The object shall never be used after it is disposed. */ - def dispose(): Unit = { + override def dispose(): Unit = { if (data != null) { data.foreach(arr => if (arr != null) arr.dispose()) } if (label != null) { label.foreach(arr => if (arr != null) arr.dispose()) } + super.dispose() } // The name and shape of data @@ -198,7 +208,6 @@ class DataBatch(val data: IndexedSeq[NDArray], def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc - } object DataBatch { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 77b1f60aba3a..bd9d5a2b3614 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -567,9 +567,9 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArrayCollector.collect(this) } - override def nativeAddress: CPtrAddress = handle + override def nativeResource: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree - override def bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product + override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product override val phantomRef: NativeResourceRef = super.register() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 8a06a96c9a8e..ac5e4e778843 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -22,15 +22,16 @@ import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference} import java.util.concurrent._ import org.apache.mxnet.Base.checkCall -import java.lang.{AutoCloseable, ThreadLocal} import java.util.concurrent.atomic.AtomicLong + +import org.apache.mxnet.NativeResourceRef.phantomRefMap import org.slf4j.{Logger, LoggerFactory} /** * NativeResource trait is used to manage MXNet Objects * such as NDArray, Symbol, Executor, etc., * The MXNet Object calls {@link NativeResource.register} - * and assign the returned NativeResourceRef to {@link phantomRef} + * and assign the returned NativeResourceRef to {@link PhantomReference} * NativeResource also implements AutoCloseable so MXNetObjects * can be used like Resources in try-with-resources paradigm */ @@ -40,13 +41,11 @@ private[mxnet] trait NativeResource /** * native Address associated with this object - * @return */ - def nativeAddress: CPtrAddress + def nativeResource: CPtrAddress /** * Function Pointer to the NativeDeAllocator of {@link nativeAddress} - * @return */ def nativeDeAllocator: (CPtrAddress => Int) @@ -56,9 +55,9 @@ private[mxnet] trait NativeResource /** * Off-Heap Bytes Allocated for this object - * @return */ - def bytesAllocated: Long + // intentionally making it a val, so it gets evaluated when defined + val bytesAllocated: Long private var scope: ResourceScope = null @@ -72,11 +71,10 @@ private[mxnet] trait NativeResource * using PhantomReference */ def register(): NativeResourceRef = { - scope = ResourceScope.getScope() - if (scope != null) { - scope.register(this) - } + if (scope != null) scope.register(this) + + NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated) // register with PhantomRef tracking to release incase the objects go // out of reference within scope but are held for long time NativeResourceRef.register(this, nativeDeAllocator) @@ -91,93 +89,83 @@ private[mxnet] trait NativeResource if (scope != null && removeFromScope) scope.deRegister(this) } - /** - * Implements {@link AutoCloseable.close} - */ + // Implements {@link AutoCloseable.close} override def close(): Unit = { dispose() - deRegister(true) } - /** - * Implements {@link WarnIfNotDisposed.dispose} - */ + // Implements {@link WarnIfNotDisposed.dispose} def dispose(): Unit = { dispose(true) } def dispose(removeFromScope: Boolean): Unit = { if (!disposed) { - print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) - checkCall(nativeDeAllocator(this.nativeAddress)) + print("NativeResource: Disposing NativeResource:%x\n".format(nativeResource)) + checkCall(nativeDeAllocator(this.nativeResource)) deRegister(removeFromScope) + NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) disposed = true } } - - override protected def finalize(): Unit = { - if (!isDisposed) { - print("LEAK: %x\n".format(this.nativeAddress)) - super.finalize() - } - } } // scalastyle:on finalize -// do not make nativeRes a member, this will hold reference and GC will not clear the object. +private[mxnet] object NativeResource { + var totalBytesAllocated : AtomicLong = new AtomicLong(0) +} +// do not make resource a member, this will hold reference and GC will not clear the object. private[mxnet] class NativeResourceRef(resource: NativeResource, val resourceDeAllocator: CPtrAddress => Int) - extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) { -} + extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) {} private[mxnet] object NativeResourceRef { - private val referenceQueue: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource] + private[mxnet] val referenceQueue: ReferenceQueue[NativeResource] + = new ReferenceQueue[NativeResource] - private val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() + private[mxnet] val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() - private val cleanupThread = new ResourceCleanupThread() + private[mxnet] val cleaner = new ResourceCleanupThread() - cleanupThread.start() + cleaner.start() def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)): NativeResourceRef = { val resourceRef = new NativeResourceRef(resource, nativeDeAllocator) - phantomRefMap.put(resourceRef, resource.nativeAddress) + phantomRefMap.put(resourceRef, resource.nativeResource) resourceRef } def deRegister(resourceRef: NativeResourceRef): Unit = { - val nativeDeAllocator = phantomRefMap.get(resourceRef) - if (nativeDeAllocator != 0L) { // since CPtrAddress is Scala Long, it cannot be null + if (phantomRefMap.containsKey(resourceRef)) { phantomRefMap.remove(resourceRef) } } - def cleanup(): Unit = { - print("NativeResourceRef: cleanup\n") - // remove is a blocking call - val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef] - print("NativeResourceRef: got a reference with deAlloc\n") - // phantomRef will be removed from the map when NativeResource.close is called. - val resource = phantomRefMap.get(ref) - - if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null - print("NativeResourceRef: got a reference for resource\n") - ref.resourceDeAllocator(resource) - phantomRefMap.remove(ref) - } - } - - private class ResourceCleanupThread extends Thread { + protected class ResourceCleanupThread extends Thread { setPriority(Thread.MAX_PRIORITY) setName("NativeResourceDeAllocatorThread") setDaemon(true) + def deAllocate(): Unit = { + print("NativeResourceRef: cleanup\n") + // remove is a blocking call + val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef] + print("NativeResourceRef: got a reference with deAlloc\n") + // phantomRef will be removed from the map when NativeResource.close is called. + val resource = phantomRefMap.get(ref) + if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null + print("NativeResourceRef: got a reference for resource\n") + ref.resourceDeAllocator(resource) + phantomRefMap.remove(ref) + } + } + override def run(): Unit = { while (true) { try { - cleanup() + deAllocate() } catch { case _: InterruptedException => Thread.currentThread().interrupt() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 873aa56082d6..402302446ea0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -28,7 +28,7 @@ class ResourceScope extends AutoCloseable { override def close(): Unit = { resourceQ.foreach(resource => if (resource != null) { - logger.info("releasing resource:%x\n".format(resource.nativeAddress)) + logger.info("releasing resource:%x\n".format(resource.nativeResource)) resource.dispose(false) } else {logger.info("found resource which is null")} ) @@ -36,13 +36,12 @@ class ResourceScope extends AutoCloseable { } private[mxnet] def register(resource: NativeResource): Unit = { - logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) + logger.info("ResourceScope: Registering Resource %x".format(resource.nativeResource)) resourceQ.+=(resource) } - // TODO(@nswamy): this is linear in time, find better data structure private[mxnet] def deRegister(resource: NativeResource): Unit = { - logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) + logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeResource)) resourceQ.-=(resource) } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 7dd86720b6b0..9d1f62b2a921 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -33,9 +33,9 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol]) // unable to get the byteAllocated for Symbol - override def bytesAllocated: Long = 0L + override val bytesAllocated: Long = 0L - override def nativeAddress: CPtrAddress = handle + override def nativeResource: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree override val phantomRef: NativeResourceRef = super.register() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index f7f858deb82d..04a727c632a9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, dataName: String = "data", labelName: String = "label") - extends DataIter with WarnIfNotDisposed { + extends DataIter with NativeResource { private val logger = LoggerFactory.getLogger(classOf[MXDataIter]) @@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, } } + override def nativeResource: CPtrAddress = handle - private var disposed = false - protected def isDisposed = disposed + override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree - /** - * Release the native memory. - * The object shall never be used after it is disposed. - */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxDataIterFree(handle) - disposed = true - } - } + override val phantomRef: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L /** * reset the iterator diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala new file mode 100644 index 000000000000..667f12625424 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import java.lang.ref.ReferenceQueue +import java.util.concurrent.ConcurrentHashMap + +import org.apache.mxnet +import org.apache.mxnet.Base.CPtrAddress +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation} + +@TagAnnotation("resource") +class DataBatchSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + object TestRef { + def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.referenceQueue} + def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] + = {NativeResourceRef.phantomRefMap} + def getCleaner: Thread = { NativeResourceRef.cleaner } + } + + class TestRef(resource: NativeResource, + resourceDeAllocator: CPtrAddress => Int) + extends NativeResourceRef(resource, resourceDeAllocator) { + } + + test(testName = "test DataBatch dispose") { + val dataArray: IndexedSeq[NDArray] + = IndexedSeq.fill[NDArray](10)(NDArray.ones(Shape (3, 4))) + val labelArray = + IndexedSeq.fill[NDArray](10)(NDArray.ones(Shape (1, 2))) + val index = IndexedSeq.fill[Long](10)(0L) + val dBatch: DataBatch = new DataBatch(dataArray, labelArray, index, 0) + val dBatchSpy = spy(dBatch) + + val aRefs = dataArray.map(_.phantomRef) + val batchRef = dBatch.phantomRef + + aRefs.foreach(r => assert(TestRef.getRefMap.containsKey(r) == true)) + assert(TestRef.getRefMap.containsKey(batchRef)) + dBatchSpy.close() + verify(dBatchSpy, times(1)).dispose() + aRefs.foreach(r => assert(TestRef.getRefMap.containsKey(r) == false)) + assert(TestRef.getRefMap.containsKey(batchRef) == false) + } +} + diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala new file mode 100644 index 000000000000..91a4521d4150 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import java.lang.ref.ReferenceQueue +import java.util.concurrent.ConcurrentHashMap + +import org.apache.mxnet.Base.CPtrAddress +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation} +import org.mockito.Mockito._ +import scala.ref.PhantomReference + +@TagAnnotation("resource") +class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + object TestRef { + def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.referenceQueue} + def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] + = {NativeResourceRef.phantomRefMap} + def getCleaner: Thread = { NativeResourceRef.cleaner } + } + + class TestRef(resource: NativeResource, + resourceDeAllocator: CPtrAddress => Int) + extends NativeResourceRef(resource, resourceDeAllocator) { + } + + test(testName = "test native resource setup/teardown") { + val a = spy(NDArray.ones(Shape(2, 3))) + val aRef = a.phantomRef + val spyRef = spy(aRef) + + assert(TestRef.getRefMap.containsKey(aRef) == true) + a.close() + verify(a).dispose() + verify(a).nativeDeAllocator + // resourceDeAllocator does not get called when explicitly closing + verify(spyRef, times(0)).resourceDeAllocator + + assert(TestRef.getRefMap.containsKey(aRef) == false) + assert(a.isDisposed == true, "isDisposed should be set to true after calling close") + } + + test(testName = "test dispose") { + val a: NDArray = NDArray.ones(Shape(3, 4)) + val aRef = a.phantomRef + val spyRef = spy(aRef) + a.dispose() + verify(spyRef).resourceDeAllocator + assert(TestRef.getRefMap.containsKey(aRef) == false) + assert(a.isDisposed == true, "isDisposed should be set to true after calling close") + } +} + From 980db5a44e84f92357d39b296e2a4e4b5197c6f8 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 24 Sep 2018 07:37:21 -0400 Subject: [PATCH 14/21] updates to NativeResource/NativeResourceRef and unit tests to NativeResource --- .../scala/org/apache/mxnet/Executor.scala | 4 +- .../src/main/scala/org/apache/mxnet/IO.scala | 15 +-- .../main/scala/org/apache/mxnet/NDArray.scala | 4 +- .../org/apache/mxnet/NativeResource.scala | 105 ++++++++++-------- .../org/apache/mxnet/ResourceScope.scala | 6 +- .../main/scala/org/apache/mxnet/Symbol.scala | 5 +- .../org/apache/mxnet/io/MXDataIter.scala | 4 +- .../org/apache/mxnet/DataBatchSuite.scala | 63 ----------- .../apache/mxnet/NativeResourceSuite.scala | 23 ++-- 9 files changed, 91 insertions(+), 138 deletions(-) delete mode 100644 scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 5b7941d8557f..581109e3ba07 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -59,11 +59,11 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, private[mxnet] var _group2ctx: Map[String, Context] = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) - override def nativeResource: CPtrAddress = handle + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree // cannot determine the off-heap size of this object override val bytesAllocated: Long = 0 - override val phantomRef: NativeResourceRef = super.register() + override val ref: NativeResourceRef = super.register() /** * Return a new executor with the same symbol and shared memory, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 36761fae8e0f..e8351422c488 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -144,8 +144,7 @@ class DataBatch(val data: IndexedSeq[NDArray], // use DataDesc to indicate the order of data/label loading // (must match the order of input data/label) private val providedDataDesc: IndexedSeq[DataDesc], - private val providedLabelDesc: IndexedSeq[DataDesc]) extends - NativeResource { + private val providedLabelDesc: IndexedSeq[DataDesc]) { // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)] // However, since the data and label can be accessed publicly (no getter and setter) // the change on this will break BC @@ -163,26 +162,17 @@ class DataBatch(val data: IndexedSeq[NDArray], this(data, label, index, pad, bucketKey, DataDesc.ListMap2Descs(providedData), DataDesc.ListMap2Descs(providedLabel)) } - - // overriding here so DataBatch gets added to Scope and can be disposed - override def nativeResource: CPtrAddress = 0 - override def nativeDeAllocator: CPtrAddress => MXUint = doNothingDeAllocator - def doNothingDeAllocator(x: CPtrAddress): MXUint = {0} - override val phantomRef: NativeResourceRef = super.register() - override val bytesAllocated: DataIterCreator = 0 - /** * Dispose its data and labels * The object shall never be used after it is disposed. */ - override def dispose(): Unit = { + def dispose(): Unit = { if (data != null) { data.foreach(arr => if (arr != null) arr.dispose()) } if (label != null) { label.foreach(arr => if (arr != null) arr.dispose()) } - super.dispose() } // The name and shape of data @@ -208,6 +198,7 @@ class DataBatch(val data: IndexedSeq[NDArray], def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc + } object DataBatch { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index bd9d5a2b3614..f039a0d88ab3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -567,11 +567,11 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArrayCollector.collect(this) } - override def nativeResource: CPtrAddress = handle + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * this.shape.product - override val phantomRef: NativeResourceRef = super.register() + override val ref: NativeResourceRef = super.register() // record arrays who construct this array instance // we use weak reference to prevent gc blocking diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index ac5e4e778843..7cf63838240b 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -24,34 +24,31 @@ import java.util.concurrent._ import org.apache.mxnet.Base.checkCall import java.util.concurrent.atomic.AtomicLong -import org.apache.mxnet.NativeResourceRef.phantomRefMap -import org.slf4j.{Logger, LoggerFactory} /** * NativeResource trait is used to manage MXNet Objects * such as NDArray, Symbol, Executor, etc., - * The MXNet Object calls {@link NativeResource.register} - * and assign the returned NativeResourceRef to {@link PhantomReference} + * The MXNet Object calls NativeResource.register + * and assign the returned NativeResourceRef to PhantomReference * NativeResource also implements AutoCloseable so MXNetObjects * can be used like Resources in try-with-resources paradigm */ -// scalastyle:off finalize private[mxnet] trait NativeResource extends AutoCloseable with WarnIfNotDisposed { /** * native Address associated with this object */ - def nativeResource: CPtrAddress + def nativeAddress: CPtrAddress /** - * Function Pointer to the NativeDeAllocator of {@link nativeAddress} + * Function Pointer to the NativeDeAllocator of nativeAddress */ def nativeDeAllocator: (CPtrAddress => Int) - /** Call {@link NativeResource.register} to get {@link NativeResourceRef} + /** Call NativeResource.register to get the reference */ - val phantomRef: NativeResourceRef + val ref: NativeResourceRef /** * Off-Heap Bytes Allocated for this object @@ -59,11 +56,12 @@ private[mxnet] trait NativeResource // intentionally making it a val, so it gets evaluated when defined val bytesAllocated: Long - private var scope: ResourceScope = null + private[mxnet] var scope: ResourceScope = null @volatile var disposed = false - override def isDisposed: Boolean = disposed + override def isDisposed: Boolean = disposed || isDeAllocated + /** * Register this object for PhantomReference tracking and within * ResourceScope if used inside ResourceScope. @@ -80,12 +78,9 @@ private[mxnet] trait NativeResource NativeResourceRef.register(this, nativeDeAllocator) } - /** - * Removes this object from PhantomRef tracking and from ResourceScope - * @param removeFromScope - */ + // Removes this object from PhantomRef tracking and from ResourceScope private def deRegister(removeFromScope: Boolean = true): Unit = { - NativeResourceRef.deRegister(phantomRef) + NativeResourceRef.deRegister(ref) if (scope != null && removeFromScope) scope.deRegister(this) } @@ -101,30 +96,40 @@ private[mxnet] trait NativeResource def dispose(removeFromScope: Boolean): Unit = { if (!disposed) { - print("NativeResource: Disposing NativeResource:%x\n".format(nativeResource)) - checkCall(nativeDeAllocator(this.nativeResource)) + print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) + checkCall(nativeDeAllocator(this.nativeAddress)) deRegister(removeFromScope) NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) disposed = true } } + + /* + this is used by the WarnIfNotDisposed finalizer, + the object could be disposed by the GC without the need for explicit disposal + but the finalizer might not have run, then the WarnIfNotDisposed throws a warning + */ + private def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref) + } -// scalastyle:on finalize private[mxnet] object NativeResource { var totalBytesAllocated : AtomicLong = new AtomicLong(0) } -// do not make resource a member, this will hold reference and GC will not clear the object. + +/* Do not make resource a member of the class, +this will hold reference and GC will not clear the object. + */ private[mxnet] class NativeResourceRef(resource: NativeResource, val resourceDeAllocator: CPtrAddress => Int) - extends PhantomReference[NativeResource](resource, NativeResourceRef.referenceQueue) {} + extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {} private[mxnet] object NativeResourceRef { - private[mxnet] val referenceQueue: ReferenceQueue[NativeResource] + private[mxnet] val refQ: ReferenceQueue[NativeResource] = new ReferenceQueue[NativeResource] - private[mxnet] val phantomRefMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() + private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, CPtrAddress]() private[mxnet] val cleaner = new ResourceCleanupThread() @@ -132,14 +137,40 @@ private[mxnet] object NativeResourceRef { def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => Int)): NativeResourceRef = { - val resourceRef = new NativeResourceRef(resource, nativeDeAllocator) - phantomRefMap.put(resourceRef, resource.nativeResource) - resourceRef + val ref = new NativeResourceRef(resource, nativeDeAllocator) + refMap.put(ref, resource.nativeAddress) + ref } - def deRegister(resourceRef: NativeResourceRef): Unit = { - if (phantomRefMap.containsKey(resourceRef)) { - phantomRefMap.remove(resourceRef) + def deRegister(ref: NativeResourceRef): Unit = { + if (refMap.containsKey(ref)) { + refMap.remove(ref) + } + } + + /** + * This method will check if the cleaner ran and deAllocated the object + * As a part of GC, when the object is unreachable GC inserts a phantomRef + * to the ReferenceQueue which the cleaner thread will deallocate, however + * the finalizer runs much later depending on the GC. + * @param resource resource to verify if it has been deAllocated + * @return true if already deAllocated + */ + def isDeAllocated(ref: NativeResourceRef): Boolean = { + !refMap.containsKey(ref) + } + + def cleanup: Unit = { + print("NativeResourceRef: cleanup\n") + // remove is a blocking call + val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef] + print("NativeResourceRef: got a reference with deAlloc\n") + // phantomRef will be removed from the map when NativeResource.close is called. + val resource = refMap.get(ref) + if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null + print("NativeResourceRef: got a reference for resource\n") + ref.resourceDeAllocator(resource) + refMap.remove(ref) } } @@ -148,24 +179,10 @@ private[mxnet] object NativeResourceRef { setName("NativeResourceDeAllocatorThread") setDaemon(true) - def deAllocate(): Unit = { - print("NativeResourceRef: cleanup\n") - // remove is a blocking call - val ref: NativeResourceRef = referenceQueue.remove().asInstanceOf[NativeResourceRef] - print("NativeResourceRef: got a reference with deAlloc\n") - // phantomRef will be removed from the map when NativeResource.close is called. - val resource = phantomRefMap.get(ref) - if (resource != 0L) { // since CPtrAddress is Scala Long, it cannot be null - print("NativeResourceRef: got a reference for resource\n") - ref.resourceDeAllocator(resource) - phantomRefMap.remove(ref) - } - } - override def run(): Unit = { while (true) { try { - deAllocate() + NativeResourceRef.cleanup } catch { case _: InterruptedException => Thread.currentThread().interrupt() diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 402302446ea0..6e134bcaeff3 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -28,7 +28,7 @@ class ResourceScope extends AutoCloseable { override def close(): Unit = { resourceQ.foreach(resource => if (resource != null) { - logger.info("releasing resource:%x\n".format(resource.nativeResource)) + logger.info("releasing resource:%x\n".format(resource.nativeAddress)) resource.dispose(false) } else {logger.info("found resource which is null")} ) @@ -36,12 +36,12 @@ class ResourceScope extends AutoCloseable { } private[mxnet] def register(resource: NativeResource): Unit = { - logger.info("ResourceScope: Registering Resource %x".format(resource.nativeResource)) + logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) resourceQ.+=(resource) } private[mxnet] def deRegister(resource: NativeResource): Unit = { - logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeResource)) + logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) resourceQ.-=(resource) } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 9d1f62b2a921..b45f9dcca465 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -34,10 +34,9 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso // unable to get the byteAllocated for Symbol override val bytesAllocated: Long = 0L - - override def nativeResource: CPtrAddress = handle + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree - override val phantomRef: NativeResourceRef = super.register() + override val ref: NativeResourceRef = super.register() def +(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Plus")(Array(this, other)) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala index 04a727c632a9..998017750db2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala @@ -67,11 +67,11 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle, } } - override def nativeResource: CPtrAddress = handle + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree - override val phantomRef: NativeResourceRef = super.register() + override val ref: NativeResourceRef = super.register() override val bytesAllocated: Long = 0L diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala deleted file mode 100644 index 667f12625424..000000000000 --- a/scala-package/core/src/test/scala/org/apache/mxnet/DataBatchSuite.scala +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mxnet - -import java.lang.ref.ReferenceQueue -import java.util.concurrent.ConcurrentHashMap - -import org.apache.mxnet -import org.apache.mxnet.Base.CPtrAddress -import org.mockito.Mockito._ -import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation} - -@TagAnnotation("resource") -class DataBatchSuite extends FunSuite with BeforeAndAfterAll with Matchers { - - object TestRef { - def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.referenceQueue} - def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] - = {NativeResourceRef.phantomRefMap} - def getCleaner: Thread = { NativeResourceRef.cleaner } - } - - class TestRef(resource: NativeResource, - resourceDeAllocator: CPtrAddress => Int) - extends NativeResourceRef(resource, resourceDeAllocator) { - } - - test(testName = "test DataBatch dispose") { - val dataArray: IndexedSeq[NDArray] - = IndexedSeq.fill[NDArray](10)(NDArray.ones(Shape (3, 4))) - val labelArray = - IndexedSeq.fill[NDArray](10)(NDArray.ones(Shape (1, 2))) - val index = IndexedSeq.fill[Long](10)(0L) - val dBatch: DataBatch = new DataBatch(dataArray, labelArray, index, 0) - val dBatchSpy = spy(dBatch) - - val aRefs = dataArray.map(_.phantomRef) - val batchRef = dBatch.phantomRef - - aRefs.foreach(r => assert(TestRef.getRefMap.containsKey(r) == true)) - assert(TestRef.getRefMap.containsKey(batchRef)) - dBatchSpy.close() - verify(dBatchSpy, times(1)).dispose() - aRefs.foreach(r => assert(TestRef.getRefMap.containsKey(r) == false)) - assert(TestRef.getRefMap.containsKey(batchRef) == false) - } -} - diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala index 91a4521d4150..58b94ecb02f7 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala @@ -21,17 +21,17 @@ import java.lang.ref.ReferenceQueue import java.util.concurrent.ConcurrentHashMap import org.apache.mxnet.Base.CPtrAddress +import org.mockito.Matchers.any import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation} import org.mockito.Mockito._ -import scala.ref.PhantomReference @TagAnnotation("resource") class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers { object TestRef { - def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.referenceQueue} + def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ} def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] - = {NativeResourceRef.phantomRefMap} + = {NativeResourceRef.refMap} def getCleaner: Thread = { NativeResourceRef.cleaner } } @@ -42,7 +42,7 @@ class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers test(testName = "test native resource setup/teardown") { val a = spy(NDArray.ones(Shape(2, 3))) - val aRef = a.phantomRef + val aRef = a.ref val spyRef = spy(aRef) assert(TestRef.getRefMap.containsKey(aRef) == true) @@ -57,13 +57,22 @@ class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers } test(testName = "test dispose") { - val a: NDArray = NDArray.ones(Shape(3, 4)) - val aRef = a.phantomRef + val a: NDArray = spy(NDArray.ones(Shape(3, 4))) + val aRef = a.ref val spyRef = spy(aRef) a.dispose() - verify(spyRef).resourceDeAllocator + verify(a).nativeDeAllocator assert(TestRef.getRefMap.containsKey(aRef) == false) assert(a.isDisposed == true, "isDisposed should be set to true after calling close") } + + test(testName = "test dispose not removing from resourceScope") { + val a: NDArray = spy(NDArray.ones(Shape(3, 4))) + val r: ResourceScope = mock(classOf[ResourceScope]) + when(a.scope).thenReturn(r) + a.dispose(false) + verify(r, times(0)).deRegister(any[NativeResource]) + } + } From 2b0b073f47de807e2bf536b8cf7b1fcb70777f8a Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 12 Oct 2018 15:56:28 -0700 Subject: [PATCH 15/21] remove redundant code added because of the object equality that was needed --- .../org/apache/mxnet/NativeResource.scala | 20 ++++++++----------- .../apache/mxnet/NativeResourceSuite.scala | 9 --------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 7cf63838240b..1017f03a6361 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -63,14 +63,14 @@ private[mxnet] trait NativeResource override def isDisposed: Boolean = disposed || isDeAllocated /** - * Register this object for PhantomReference tracking and within + * Register this object for PhantomReference tracking and in * ResourceScope if used inside ResourceScope. * @return NativeResourceRef that tracks reachability of this object * using PhantomReference */ def register(): NativeResourceRef = { - scope = ResourceScope.getScope() - if (scope != null) scope.register(this) + scope = ResourceScope.getCurrentScope() + if (scope != null) scope.add(this) NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated) // register with PhantomRef tracking to release incase the objects go @@ -79,26 +79,22 @@ private[mxnet] trait NativeResource } // Removes this object from PhantomRef tracking and from ResourceScope - private def deRegister(removeFromScope: Boolean = true): Unit = { + private def deRegister(): Unit = { NativeResourceRef.deRegister(ref) - if (scope != null && removeFromScope) scope.deRegister(this) + if (scope != null) scope.remove(this) } - // Implements {@link AutoCloseable.close} + // Implements [[@link AutoCloseable.close]] override def close(): Unit = { dispose() } - // Implements {@link WarnIfNotDisposed.dispose} + // Implements [[@link WarnIfNotDisposed.dispose]] def dispose(): Unit = { - dispose(true) - } - - def dispose(removeFromScope: Boolean): Unit = { if (!disposed) { print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocator(this.nativeAddress)) - deRegister(removeFromScope) + deRegister() NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) disposed = true } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala index 58b94ecb02f7..81a9f605a887 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala @@ -65,14 +65,5 @@ class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers assert(TestRef.getRefMap.containsKey(aRef) == false) assert(a.isDisposed == true, "isDisposed should be set to true after calling close") } - - test(testName = "test dispose not removing from resourceScope") { - val a: NDArray = spy(NDArray.ones(Shape(3, 4))) - val r: ResourceScope = mock(classOf[ResourceScope]) - when(a.scope).thenReturn(r) - a.dispose(false) - verify(r, times(0)).deRegister(any[NativeResource]) - } - } From 18b11751832b337992dcc081f7d4bdb8b03c42be Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Fri, 12 Oct 2018 15:56:59 -0700 Subject: [PATCH 16/21] add ResourceScope --- .../org/apache/mxnet/ResourceScope.scala | 196 ++++++++++++++---- .../org/apache/mxnet/ResourceScopeSuite.scala | 68 ++++++ 2 files changed, 228 insertions(+), 36 deletions(-) create mode 100644 scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 6e134bcaeff3..2511da0e6e32 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -18,31 +18,66 @@ package org.apache.mxnet import org.slf4j.LoggerFactory + import scala.collection.mutable.ArrayBuffer +import scala.util.Try +import scala.util.control.{ControlThrowable, NonFatal} +import java.util.Comparator +/** + * This class manages releasing of Nativeresources + */ class ResourceScope extends AutoCloseable { - import ResourceScope.{logger, resourceScope} - private val resourceQ = new ArrayBuffer[NativeResource]() - resourceScope.get().+=(this) + import ResourceScope.{logger, threadLocalScopes, addToLocalScope, removeFromLocalScope} + private[mxnet] val resourceQ = new ArrayBuffer[NativeResource] { + // this override is required for object equality check instead of content equality + override def indexOf[B >: NativeResource](elem: B, from: Int): Int = { + indexWhere(elem.asInstanceOf[NativeResource].nativeAddress == + _.nativeAddress, from) + } + override def lastIndexOf[B >: NativeResource](elem: B): Int = { + lastIndexWhere(elem.asInstanceOf[NativeResource].nativeAddress == _.nativeAddress) + } + } + + ResourceScope.addToLocalScope(this) + + /** + * Releases all the [[NativeResource]] by calling + * the associated [[NativeResource.dispose()]] method + */ override def close(): Unit = { resourceQ.foreach(resource => if (resource != null) { logger.info("releasing resource:%x\n".format(resource.nativeAddress)) - resource.dispose(false) - } else {logger.info("found resource which is null")} + resource.close() + } else { + logger.info("found resource which is null") + } ) - ResourceScope.resourceScope.get().-=(this) + resourceQ.clear() + ResourceScope.removeFromLocalScope(this) } - private[mxnet] def register(resource: NativeResource): Unit = { + /** + * Add a Native Resource to the scope + * @param resource + */ + private[mxnet] def add(resource: NativeResource): Unit = { logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) resourceQ.+=(resource) } - private[mxnet] def deRegister(resource: NativeResource): Unit = { + /** + * Remove Native Resource from the Scope, this uses + * object equality to find the resource in the stack. + * @param resource + */ + private[mxnet] def remove(resource: NativeResource): Unit = { logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) resourceQ.-=(resource) + logger.info("resourceQ size: %d".format(resourceQ.size)) } } @@ -50,35 +85,124 @@ object ResourceScope { private val logger = LoggerFactory.getLogger(classOf[ResourceScope]) - // inspired from slide 21 of - def using[T](resource: ResourceScope)(block: => T): T = { - require(resource != null) - try { - val ret = block - ret match { - case nRes: NativeResource => - resource.deRegister(nRes.asInstanceOf[NativeResource]) - case _ => // do nothing - } - ret - } finally { - // TODO(nswamy@): handle exceptions - resource.close - } - } - - private[mxnet] val resourceScope = new ThreadLocal[ArrayBuffer[ResourceScope]] { - override def initialValue(): ArrayBuffer[ResourceScope] = - new ArrayBuffer[ResourceScope]() - } - - private[mxnet] def getScope(): ResourceScope = { + /** + * Captures all Native Resources created using the ResourceScope and + * at the end of the body, de allocates all the Native resources by calling close on them. + * + * @param localScope Scope in which to capture the native resources + * @param body block of code to execute in this scope + * @tparam R the type of the resource + * @tparam A return type + * @return result of the operation, if the result is of type NativeResource, it is not + * de allocated so the user can use it and then de allocate manually by calling + * close or enclose in another resourceScope. + */ + // inspired from slide 21 of https://www.slideshare.net/Odersky/fosdem-2009-1013261 + // and https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala + // TODO: we should move to the Scala util's Using method when we move to Scala 2.13 + def using[A](scope: ResourceScope = null)(body: => A): A = { + + val curScope = if (scope != null) scope else new ResourceScope() + + val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope() + + @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = { + g.foreach( n => + n match { + case nRes: NativeResource => { + removeAndAddToPrevScope(nRes) + } + case kv: scala.Tuple2[_, _] => { + if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + kv._1.asInstanceOf[NativeResource]) + if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope( + kv._2.asInstanceOf[NativeResource]) + } + } + ) + } + + @inline def removeAndAddToPrevScope(r: NativeResource) = { + curScope.remove(r) + if (prevScope.isDefined) { + prevScope.get.add(r) + r.scope = prevScope.get + } + } + + @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = { + if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed) + } + + var retThrowable: Throwable = null + try { - resourceScope.get().last + val ret = body + ret match { + // don't de-allocate if returning any collection that contains NativeResource. + case resInGeneric: scala.collection.Iterable[_] => resourceInGeneric(resInGeneric) + case nRes: NativeResource => removeAndAddToPrevScope(nRes) + case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => removeAndAddToPrevScope(nd) ) + case _ => // do nothing + } + ret } catch { - case _: ArrayIndexOutOfBoundsException => null - case _: NoSuchElementException => null - case e: Exception => throw e + case t: Throwable => + retThrowable = t + null.asInstanceOf[A] // we'll throw in finally + } finally { + var toThrow: Throwable = retThrowable + if (retThrowable eq null) curScope.close() + else { + try { + curScope.close + } catch { + case closeThrowable: Throwable => + if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = closeThrowable + else safeAddSuppressed(retThrowable, closeThrowable) + } finally { + throw toThrow + } + } } - } + } + + // thread local Scopes + private[mxnet] val threadLocalScopes = new ThreadLocal[ArrayBuffer[ResourceScope]] { + override def initialValue(): ArrayBuffer[ResourceScope] = + new ArrayBuffer[ResourceScope]() + } + + /** + * Add resource to current ThreadLocal DataStructure + * @param r ResourceScope to add. + */ + private[mxnet] def addToLocalScope(r: ResourceScope): Unit = { + threadLocalScopes.get() += r + } + + /** + * Remove resource from current ThreadLocal DataStructure + * @param r ResourceScope to remove + */ + private[mxnet] def removeFromLocalScope(r: ResourceScope): Unit = { + threadLocalScopes.get() -= r + } + + /** + * Get the latest Scope in the stack + * @return + */ + private[mxnet] def getCurrentScope(): ResourceScope = { + Try(threadLocalScopes.get().last).getOrElse(null) + } + + /** + * Get the Last but one Scope from threadLocal Scopes. + * @return n-1th scope or None when not found + */ + private[mxnet] def getPrevScope(): Option[ResourceScope] = { + val scopes = threadLocalScopes.get() + Try(Some(scopes(scopes.size - 2))).getOrElse(None) + } } \ No newline at end of file diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala new file mode 100644 index 000000000000..54904b9e6763 --- /dev/null +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnet + +import java.lang.ref.ReferenceQueue +import java.util.concurrent.ConcurrentHashMap + +import org.apache.mxnet.Base.CPtrAddress +import org.apache.mxnet.ResourceScope.logger +import org.mockito.Matchers.any +import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} +import org.mockito.Mockito._ + +class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { + + object TestRef { + def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ} + def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] + = {NativeResourceRef.refMap} + def getCleaner: Thread = { NativeResourceRef.cleaner } + } + + class TestRef(resource: NativeResource, + resourceDeAllocator: CPtrAddress => Int) + extends NativeResourceRef(resource, resourceDeAllocator) { + } + + test(testName = "testAutoReleasefromScope") { + var a: NDArray = null + var aRef: NativeResourceRef = null + val b: NDArray = ResourceScope.using() { + a = spy(NDArray.ones(Shape(3, 4))) + print("testAutoReleasefromScope: a address %x\n".format(a.nativeAddress)) + aRef = a.ref + val x = NDArray.ones(Shape(3, 4)) + print("testAutoReleasefromScope: x address %x\n".format(x.nativeAddress)) + x + } + val bRef: NativeResourceRef = b.ref + assert(a.isDisposed == true, "objects created within scope should have isDisposed set to true") + assert(b.isDisposed == false, "returned NativeResource should not be released") + assert(TestRef.getRefMap.containsKey(aRef) == false, + "reference of resource in Scope should be removed refMap") + assert(TestRef.getRefMap.containsKey(bRef) == true, + "reference of resource outside scope should be not removed refMap") + + } + + test("release from outerscope") { + var a: NDArray = null + } + +} From e2d5c9981edced195104d05f01b07a9bded7b22c Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sat, 13 Oct 2018 22:56:16 -0700 Subject: [PATCH 17/21] Fix NativeResource to not remove from Scope, add Unit Tests to ResourceScope --- .../org/apache/mxnet/NativeResource.scala | 12 +- .../org/apache/mxnet/ResourceScope.scala | 25 ++-- .../org/apache/mxnet/ResourceScopeSuite.scala | 125 +++++++++++++++--- 3 files changed, 121 insertions(+), 41 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index 1017f03a6361..deecf22e5e65 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -79,9 +79,9 @@ private[mxnet] trait NativeResource } // Removes this object from PhantomRef tracking and from ResourceScope - private def deRegister(): Unit = { + private def deRegister(removeFromScope: Boolean): Unit = { NativeResourceRef.deRegister(ref) - if (scope != null) scope.remove(this) + if (scope != null && removeFromScope) scope.remove(this) } // Implements [[@link AutoCloseable.close]] @@ -90,11 +90,13 @@ private[mxnet] trait NativeResource } // Implements [[@link WarnIfNotDisposed.dispose]] - def dispose(): Unit = { + def dispose(): Unit = dispose(true) + + private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = { if (!disposed) { print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocator(this.nativeAddress)) - deRegister() + deRegister(removeFromScope) NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) disposed = true } @@ -105,7 +107,7 @@ private[mxnet] trait NativeResource the object could be disposed by the GC without the need for explicit disposal but the finalizer might not have run, then the WarnIfNotDisposed throws a warning */ - private def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref) + private[mxnet] def isDeAllocated(): Boolean = NativeResourceRef.isDeAllocated(ref) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 2511da0e6e32..ceb84f1b67a6 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -25,7 +25,7 @@ import scala.util.control.{ControlThrowable, NonFatal} import java.util.Comparator /** - * This class manages releasing of Nativeresources + * This class manages automatically releasing of [[NativeResource]]s */ class ResourceScope extends AutoCloseable { @@ -46,35 +46,29 @@ class ResourceScope extends AutoCloseable { /** * Releases all the [[NativeResource]] by calling - * the associated [[NativeResource.dispose()]] method + * the associated [[NativeResource.close()]] method */ override def close(): Unit = { - resourceQ.foreach(resource => if (resource != null) { - logger.info("releasing resource:%x\n".format(resource.nativeAddress)) - resource.close() - } else { - logger.info("found resource which is null") - } - ) + resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) resourceQ.clear() ResourceScope.removeFromLocalScope(this) } /** - * Add a Native Resource to the scope + * Add a NativeResource to the scope * @param resource */ - private[mxnet] def add(resource: NativeResource): Unit = { + def add(resource: NativeResource): Unit = { logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) resourceQ.+=(resource) } /** - * Remove Native Resource from the Scope, this uses + * Remove NativeResource from the Scope, this uses * object equality to find the resource in the stack. * @param resource */ - private[mxnet] def remove(resource: NativeResource): Unit = { + def remove(resource: NativeResource): Unit = { logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) resourceQ.-=(resource) logger.info("resourceQ size: %d".format(resourceQ.size)) @@ -88,10 +82,9 @@ object ResourceScope { /** * Captures all Native Resources created using the ResourceScope and * at the end of the body, de allocates all the Native resources by calling close on them. - * - * @param localScope Scope in which to capture the native resources + * This method will not deAllocate NativeResources returned from the block. + * @param scope Scope in which to capture the native resources * @param body block of code to execute in this scope - * @tparam R the type of the resource * @tparam A return type * @return result of the operation, if the result is of type NativeResource, it is not * de allocated so the user can use it and then de allocate manually by calling diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index 54904b9e6763..a57babd5a601 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -25,44 +25,129 @@ import org.apache.mxnet.ResourceScope.logger import org.mockito.Matchers.any import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers} import org.mockito.Mockito._ +import scala.collection.mutable.HashMap class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { - object TestRef { + class TestNativeResource extends NativeResource { + /** + * native Address associated with this object + */ + override def nativeAddress: CPtrAddress = hashCode() + + /** + * Function Pointer to the NativeDeAllocator of nativeAddress + */ + override def nativeDeAllocator: CPtrAddress => Int = TestNativeResource.deAllocator + + /** Call NativeResource.register to get the reference + */ + override val ref: NativeResourceRef = super.register() + /** + * Off-Heap Bytes Allocated for this object + */ + override val bytesAllocated: Long = 0 + } + object TestNativeResource { + def deAllocator(handle: CPtrAddress): Int = 0 + } + + object TestPhantomRef { def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ} def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress] = {NativeResourceRef.refMap} def getCleaner: Thread = { NativeResourceRef.cleaner } + } - class TestRef(resource: NativeResource, - resourceDeAllocator: CPtrAddress => Int) + class TestPhantomRef(resource: NativeResource, + resourceDeAllocator: CPtrAddress => Int) extends NativeResourceRef(resource, resourceDeAllocator) { } - test(testName = "testAutoReleasefromScope") { + test(testName = "test NDArray Auto Release") { var a: NDArray = null var aRef: NativeResourceRef = null - val b: NDArray = ResourceScope.using() { - a = spy(NDArray.ones(Shape(3, 4))) - print("testAutoReleasefromScope: a address %x\n".format(a.nativeAddress)) - aRef = a.ref - val x = NDArray.ones(Shape(3, 4)) - print("testAutoReleasefromScope: x address %x\n".format(x.nativeAddress)) - x + var b: NDArray = null + + ResourceScope.using() { + b = ResourceScope.using() { + a = NDArray.ones(Shape(3, 4)) + print("testAutoReleasefromScope: a address %x\n".format(a.nativeAddress)) + aRef = a.ref + val x = NDArray.ones(Shape(3, 4)) + print("testAutoReleasefromScope: x address %x\n".format(x.nativeAddress)) + x + } + val bRef: NativeResourceRef = b.ref + assert(a.isDisposed == true, + "objects created within scope should have isDisposed set to true") + assert(b.isDisposed == false, + "returned NativeResource should not be released") + assert(TestPhantomRef.getRefMap.containsKey(aRef) == false, + "reference of resource in Scope should be removed refMap") + assert(TestPhantomRef.getRefMap.containsKey(bRef) == true, + "reference of resource outside scope should be not removed refMap") } - val bRef: NativeResourceRef = b.ref - assert(a.isDisposed == true, "objects created within scope should have isDisposed set to true") - assert(b.isDisposed == false, "returned NativeResource should not be released") - assert(TestRef.getRefMap.containsKey(aRef) == false, - "reference of resource in Scope should be removed refMap") - assert(TestRef.getRefMap.containsKey(bRef) == true, - "reference of resource outside scope should be not removed refMap") + assert(b.isDisposed, "resource returned from inner scope should be released in outer scope") + } + test("test return object release from outer scope") { + var a: TestNativeResource = null + ResourceScope.using() { + a = ResourceScope.using() { + new TestNativeResource() + } + assert(a.isDisposed == false, "returned object should not be disposed within Using") + } + assert(a.isDisposed == true, "returned object should be disposed in the outer scope") } - test("release from outerscope") { - var a: NDArray = null + test(testName = "test NativeResources in returned Lists are not disposed") { + var ndListRet: IndexedSeq[TestNativeResource] = null + ResourceScope.using() { + ndListRet = ResourceScope.using() { + val ndList: IndexedSeq[TestNativeResource] = + IndexedSeq(new TestNativeResource(), new TestNativeResource()) + ndList + } + ndListRet.foreach(nd => assert(nd.isDisposed == false, + "NativeResources within a returned collection should not be disposed")) + } + ndListRet.foreach(nd => assert(nd.isDisposed == true, + "NativeResources returned from inner scope should be disposed in outer scope")) + } + + test("test native resource inside a map") { + var nRInKeyOfMap: HashMap[TestNativeResource, String] = null + var nRInValOfMap: HashMap[String, TestNativeResource] = HashMap[String, TestNativeResource]() + + ResourceScope.using() { + nRInKeyOfMap = ResourceScope.using() { + val ret = HashMap[TestNativeResource, String]() + ret.put(new TestNativeResource, "hello") + ret + } + assert(!nRInKeyOfMap.isEmpty) + + nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed == false, + "NativeResources returned in Traversable should not be disposed")) + } + + nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed)) + + ResourceScope.using() { + + nRInValOfMap = ResourceScope.using() { + val ret = HashMap[String, TestNativeResource]() + ret.put("world!", new TestNativeResource) + ret + } + assert(!nRInValOfMap.isEmpty) + nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed == false, + "NativeResources returned in Collection should not be disposed")) + } + nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed)) } } From 21140cf695e019cebfc18dff9a2de05c990f4999 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Sat, 13 Oct 2018 23:11:42 -0700 Subject: [PATCH 18/21] cleanup log/print debug statements --- .../src/main/scala/org/apache/mxnet/NativeResource.scala | 9 ++------- .../src/main/scala/org/apache/mxnet/ResourceScope.scala | 8 +------- .../test/scala/org/apache/mxnet/ResourceScopeSuite.scala | 2 -- .../init/src/main/scala/org/apache/mxnet/init/Base.scala | 1 - 4 files changed, 3 insertions(+), 17 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index deecf22e5e65..c18f2333af04 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -94,7 +94,6 @@ private[mxnet] trait NativeResource private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = { if (!disposed) { - print("NativeResource: Disposing NativeResource:%x\n".format(nativeAddress)) checkCall(nativeDeAllocator(this.nativeAddress)) deRegister(removeFromScope) NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) @@ -115,9 +114,8 @@ private[mxnet] object NativeResource { var totalBytesAllocated : AtomicLong = new AtomicLong(0) } -/* Do not make resource a member of the class, -this will hold reference and GC will not clear the object. - */ +// Do not make [[NativeResource.resource]] a member of the class, +// this will hold reference and GC will not clear the object. private[mxnet] class NativeResourceRef(resource: NativeResource, val resourceDeAllocator: CPtrAddress => Int) extends PhantomReference[NativeResource](resource, NativeResourceRef.refQ) {} @@ -159,14 +157,11 @@ private[mxnet] object NativeResourceRef { } def cleanup: Unit = { - print("NativeResourceRef: cleanup\n") // remove is a blocking call val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef] - print("NativeResourceRef: got a reference with deAlloc\n") // phantomRef will be removed from the map when NativeResource.close is called. val resource = refMap.get(ref) if (resource != 0L) { // since CPtrAddress is Scala a Long, it cannot be null - print("NativeResourceRef: got a reference for resource\n") ref.resourceDeAllocator(resource) refMap.remove(ref) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index ceb84f1b67a6..7c38335b51c9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -22,15 +22,12 @@ import org.slf4j.LoggerFactory import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} -import java.util.Comparator /** * This class manages automatically releasing of [[NativeResource]]s */ class ResourceScope extends AutoCloseable { - import ResourceScope.{logger, threadLocalScopes, addToLocalScope, removeFromLocalScope} - private[mxnet] val resourceQ = new ArrayBuffer[NativeResource] { // this override is required for object equality check instead of content equality override def indexOf[B >: NativeResource](elem: B, from: Int): Int = { @@ -59,7 +56,6 @@ class ResourceScope extends AutoCloseable { * @param resource */ def add(resource: NativeResource): Unit = { - logger.info("ResourceScope: Registering Resource %x".format(resource.nativeAddress)) resourceQ.+=(resource) } @@ -69,9 +65,7 @@ class ResourceScope extends AutoCloseable { * @param resource */ def remove(resource: NativeResource): Unit = { - logger.info("ResourceScope: DeRegistering Resource %x".format(resource.nativeAddress)) resourceQ.-=(resource) - logger.info("resourceQ size: %d".format(resourceQ.size)) } } @@ -83,7 +77,7 @@ object ResourceScope { * Captures all Native Resources created using the ResourceScope and * at the end of the body, de allocates all the Native resources by calling close on them. * This method will not deAllocate NativeResources returned from the block. - * @param scope Scope in which to capture the native resources + * @param scope (Optional). Scope in which to capture the native resources * @param body block of code to execute in this scope * @tparam A return type * @return result of the operation, if the result is of type NativeResource, it is not diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala index a57babd5a601..41dfa7d0ead2 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala @@ -73,10 +73,8 @@ class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers { ResourceScope.using() { b = ResourceScope.using() { a = NDArray.ones(Shape(3, 4)) - print("testAutoReleasefromScope: a address %x\n".format(a.nativeAddress)) aRef = a.ref val x = NDArray.ones(Shape(3, 4)) - print("testAutoReleasefromScope: x address %x\n".format(x.nativeAddress)) x } val bRef: NativeResourceRef = b.ref diff --git a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala index a43df10fcb7e..7402dbd3bc1d 100644 --- a/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala +++ b/scala-package/init/src/main/scala/org/apache/mxnet/init/Base.scala @@ -48,7 +48,6 @@ object Base { if (os.startsWith("Linux")) { System.load(s"$baseDir/linux-x86_64/target/libmxnet-init-scala-linux-x86_64.so") } else if (os.startsWith("Mac")) { - baseDir = "/Users/wamy/nswamy/deepengine/workspace/mxnet_scala/scala-package/init-native" System.load(s"$baseDir/osx-x86_64/target/libmxnet-init-scala-osx-x86_64.jnilib") } else { // TODO(yizhi) support windows later From 362ae18bc3d9d65e5527db725838c33bc47c0576 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Mon, 15 Oct 2018 16:49:41 -0700 Subject: [PATCH 19/21] use TreeSet inplace of ArrayBuffer to speedup removal of resources from ResourceScope Fix Executor dispose and make KVStore a NativeResource --- .../scala/org/apache/mxnet/Executor.scala | 6 ++ .../main/scala/org/apache/mxnet/KVStore.scala | 21 ++-- .../main/scala/org/apache/mxnet/Model.scala | 102 +++++++++--------- .../org/apache/mxnet/ResourceScope.scala | 17 +-- .../main/scala/org/apache/mxnet/Symbol.scala | 4 +- 5 files changed, 78 insertions(+), 72 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 581109e3ba07..612d98d566a0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -64,6 +64,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, // cannot determine the off-heap size of this object override val bytesAllocated: Long = 0 override val ref: NativeResourceRef = super.register() + override def dispose(): Unit = { + if (!disposed) { + super.dispose() + outputs.foreach(o => o.dispose()) + } + } /** * Return a new executor with the same symbol and shared memory, diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala index 8e89ce76b877..45189a13aefc 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala @@ -52,22 +52,17 @@ object KVStore { } } -class KVStore(private[mxnet] val handle: KVStoreHandle) extends WarnIfNotDisposed { +class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource { private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore]) private var updaterFunc: MXKVStoreUpdater = null - private var disposed = false - protected def isDisposed = disposed - /** - * Release the native memory. - * The object shall never be used after it is disposed. - */ - def dispose(): Unit = { - if (!disposed) { - _LIB.mxKVStoreFree(handle) - disposed = true - } - } + override def nativeAddress: CPtrAddress = handle + + override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree + + override val ref: NativeResourceRef = super.register() + + override val bytesAllocated: Long = 0L /** * Initialize a single or a sequence of key-value pairs into the store. diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala index 4bb9cdd331a6..3d4c58dfc45a 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala @@ -280,7 +280,7 @@ object Model { if (updateOnKVStore) { kvStore.foreach(_.setOptimizer(optimizer)) } - + ResourceScope.using() { // Now start training for (epoch <- beginEpoch until endEpoch) { // Training phase @@ -290,45 +290,46 @@ object Model { var epochDone = false // Iterate over training data. trainData.reset() - while (!epochDone) { - var doReset = true - while (doReset && trainData.hasNext) { - val dataBatch = trainData.next() - executorManager.loadDataBatch(dataBatch) - monitor.foreach(_.tic()) - executorManager.forward(isTrain = true) - executorManager.backward() - if (updateOnKVStore) { - updateParamsOnKVStore(executorManager.paramArrays, - executorManager.gradArrays, - kvStore, executorManager.paramNames) - } else { - updateParams(executorManager.paramArrays, - executorManager.gradArrays, - updaterLocal, ctx.length, - executorManager.paramNames, - kvStore) - } - monitor.foreach(_.tocPrint()) - // evaluate at end, so out_cpu_array can lazy copy - executorManager.updateMetric(evalMetric, dataBatch.label) +// ResourceScope.using() { + while (!epochDone) { + var doReset = true + while (doReset && trainData.hasNext) { + val dataBatch = trainData.next() + executorManager.loadDataBatch(dataBatch) + monitor.foreach(_.tic()) + executorManager.forward(isTrain = true) + executorManager.backward() + if (updateOnKVStore) { + updateParamsOnKVStore(executorManager.paramArrays, + executorManager.gradArrays, + kvStore, executorManager.paramNames) + } else { + updateParams(executorManager.paramArrays, + executorManager.gradArrays, + updaterLocal, ctx.length, + executorManager.paramNames, + kvStore) + } + monitor.foreach(_.tocPrint()) + // evaluate at end, so out_cpu_array can lazy copy + executorManager.updateMetric(evalMetric, dataBatch.label) - nBatch += 1 - batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric)) + nBatch += 1 + batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric)) - // this epoch is done possibly earlier - if (epochSize != -1 && nBatch >= epochSize) { - doReset = false + // this epoch is done possibly earlier + if (epochSize != -1 && nBatch >= epochSize) { + doReset = false + } + } + if (doReset) { + trainData.reset() } - } - if (doReset) { - trainData.reset() - } - - // this epoch is done - epochDone = (epochSize == -1 || nBatch >= epochSize) - } + // this epoch is done + epochDone = (epochSize == -1 || nBatch >= epochSize) + } +// } val (name, value) = evalMetric.get name.zip(value).foreach { case (n, v) => logger.info(s"Epoch[$epoch] Train-$n=$v") @@ -336,20 +337,22 @@ object Model { val toc = System.currentTimeMillis logger.info(s"Epoch[$epoch] Time cost=${toc - tic}") - evalData.foreach { evalDataIter => - evalMetric.reset() - evalDataIter.reset() - // TODO: make DataIter implement Iterator - while (evalDataIter.hasNext) { - val evalBatch = evalDataIter.next() - executorManager.loadDataBatch(evalBatch) - executorManager.forward(isTrain = false) - executorManager.updateMetric(evalMetric, evalBatch.label) - } + ResourceScope.using() { + evalData.foreach { evalDataIter => + evalMetric.reset() + evalDataIter.reset() + // TODO: make DataIter implement Iterator + while (evalDataIter.hasNext) { + val evalBatch = evalDataIter.next() + executorManager.loadDataBatch(evalBatch) + executorManager.forward(isTrain = false) + executorManager.updateMetric(evalMetric, evalBatch.label) + } - val (name, value) = evalMetric.get - name.zip(value).foreach { case (n, v) => - logger.info(s"Epoch[$epoch] Train-$n=$v") + val (name, value) = evalMetric.get + name.zip(value).foreach { case (n, v) => + logger.info(s"Epoch[$epoch] Train-$n=$v") + } } } @@ -361,6 +364,7 @@ object Model { updaterLocal.dispose() executorManager.dispose() + } } // scalastyle:on parameterNum } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 7c38335b51c9..64ec772392af 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -17,8 +17,11 @@ package org.apache.mxnet +import java.util.HashSet + import org.slf4j.LoggerFactory +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} @@ -28,14 +31,12 @@ import scala.util.control.{ControlThrowable, NonFatal} */ class ResourceScope extends AutoCloseable { - private[mxnet] val resourceQ = new ArrayBuffer[NativeResource] { - // this override is required for object equality check instead of content equality - override def indexOf[B >: NativeResource](elem: B, from: Int): Int = { - indexWhere(elem.asInstanceOf[NativeResource].nativeAddress == - _.nativeAddress, from) - } - override def lastIndexOf[B >: NativeResource](elem: B): Int = { - lastIndexWhere(elem.asInstanceOf[NativeResource].nativeAddress == _.nativeAddress) + // HashSet does not take a custom comparator + private[mxnet] val resourceQ = new mutable.TreeSet[NativeResource]()(nativeAddressOrdering) + + private object nativeAddressOrdering extends Ordering[NativeResource] { + def compare(a: NativeResource, b: NativeResource): Int = { + a.nativeAddress compare b.nativeAddress } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index b45f9dcca465..a009e7e343f2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -787,7 +787,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso } val execHandle = new ExecutorHandleRef - val sharedHadle = if (sharedExec != null) sharedExec.handle else 0L + val sharedHandle = if (sharedExec != null) sharedExec.handle else 0L checkCall(_LIB.mxExecutorBindEX(handle, ctx.deviceTypeid, ctx.deviceId, @@ -800,7 +800,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso argsGradHandle, reqsArray, auxArgsHandle, - sharedHadle, + sharedHandle, execHandle)) val executor = new Executor(execHandle.value, this.clone()) executor.argArrays = argsNDArray From d78f571e8214344864076acab521f83915232b42 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Thu, 18 Oct 2018 11:59:50 -0700 Subject: [PATCH 20/21] fix segfault that was happening because of NDArray creation on the fly in Optimizer --- .../scala/org/apache/mxnet/Executor.scala | 2 +- .../main/scala/org/apache/mxnet/Model.scala | 32 +++++++++---------- .../main/scala/org/apache/mxnet/NDArray.scala | 2 +- .../org/apache/mxnet/NativeResource.scala | 10 +++--- .../scala/org/apache/mxnet/Optimizer.scala | 22 +++++++++---- .../org/apache/mxnet/ResourceScope.scala | 16 +++++----- .../org/apache/mxnet/optimizer/SGD.scala | 10 ++++-- 7 files changed, 55 insertions(+), 39 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index 612d98d566a0..19fb6fe5cee5 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -65,7 +65,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, override val bytesAllocated: Long = 0 override val ref: NativeResourceRef = super.register() override def dispose(): Unit = { - if (!disposed) { + if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala index 3d4c58dfc45a..b835c4964dd0 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala @@ -259,7 +259,9 @@ object Model { workLoadList: Seq[Float] = Nil, monitor: Option[Monitor] = None, symGen: SymbolGenerator = null): Unit = { - val executorManager = new DataParallelExecutorManager( + ResourceScope.using() { + + val executorManager = new DataParallelExecutorManager( symbol = symbol, symGen = symGen, ctx = ctx, @@ -269,18 +271,18 @@ object Model { auxNames = auxNames, workLoadList = workLoadList) - monitor.foreach(executorManager.installMonitor) - executorManager.setParams(argParams, auxParams) + monitor.foreach(executorManager.installMonitor) + executorManager.setParams(argParams, auxParams) - // updater for updateOnKVStore = false - val updaterLocal = Optimizer.getUpdater(optimizer) + // updater for updateOnKVStore = false + val updaterLocal = Optimizer.getUpdater(optimizer) + + kvStore.foreach(initializeKVStore(_, executorManager.paramArrays, + argParams, executorManager.paramNames, updateOnKVStore)) + if (updateOnKVStore) { + kvStore.foreach(_.setOptimizer(optimizer)) + } - kvStore.foreach(initializeKVStore(_, executorManager.paramArrays, - argParams, executorManager.paramNames, updateOnKVStore)) - if (updateOnKVStore) { - kvStore.foreach(_.setOptimizer(optimizer)) - } - ResourceScope.using() { // Now start training for (epoch <- beginEpoch until endEpoch) { // Training phase @@ -290,7 +292,7 @@ object Model { var epochDone = false // Iterate over training data. trainData.reset() -// ResourceScope.using() { + ResourceScope.using() { while (!epochDone) { var doReset = true while (doReset && trainData.hasNext) { @@ -329,7 +331,7 @@ object Model { // this epoch is done epochDone = (epochSize == -1 || nBatch >= epochSize) } -// } + } val (name, value) = evalMetric.get name.zip(value).foreach { case (n, v) => logger.info(s"Epoch[$epoch] Train-$n=$v") @@ -351,7 +353,7 @@ object Model { val (name, value) = evalMetric.get name.zip(value).foreach { case (n, v) => - logger.info(s"Epoch[$epoch] Train-$n=$v") + logger.info(s"Epoch[$epoch] Validation-$n=$v") } } } @@ -362,8 +364,6 @@ object Model { epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams)) } - updaterLocal.dispose() - executorManager.dispose() } } // scalastyle:on parameterNum diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index f039a0d88ab3..f2a7603caa85 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -589,7 +589,7 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, * The object shall never be used after it is disposed. */ override def dispose(): Unit = { - if (!disposed) { + if (!super.isDisposed) { super.dispose() dependencies.clear() } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index c18f2333af04..e3e7c7a3aa70 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -56,9 +56,9 @@ private[mxnet] trait NativeResource // intentionally making it a val, so it gets evaluated when defined val bytesAllocated: Long - private[mxnet] var scope: ResourceScope = null + private[mxnet] var scope: Option[ResourceScope] = None - @volatile var disposed = false + @volatile private var disposed = false override def isDisposed: Boolean = disposed || isDeAllocated @@ -70,7 +70,7 @@ private[mxnet] trait NativeResource */ def register(): NativeResourceRef = { scope = ResourceScope.getCurrentScope() - if (scope != null) scope.add(this) + if (scope.isDefined) scope.get.add(this) NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated) // register with PhantomRef tracking to release incase the objects go @@ -81,7 +81,7 @@ private[mxnet] trait NativeResource // Removes this object from PhantomRef tracking and from ResourceScope private def deRegister(removeFromScope: Boolean): Unit = { NativeResourceRef.deRegister(ref) - if (scope != null && removeFromScope) scope.remove(this) + if (scope.isDefined && removeFromScope) scope.get.remove(this) } // Implements [[@link AutoCloseable.close]] @@ -183,4 +183,4 @@ private[mxnet] object NativeResourceRef { } } } -} \ No newline at end of file +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala index 758cbc829618..c3f8aaec6d60 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala @@ -19,6 +19,8 @@ package org.apache.mxnet import java.io._ +import org.apache.mxnet.Base.CPtrAddress + import scala.collection.mutable import scala.util.Either @@ -38,8 +40,10 @@ object Optimizer { } override def dispose(): Unit = { - states.values.foreach(optimizer.disposeState) - states.clear() + if (!super.isDisposed) { + states.values.foreach(optimizer.disposeState) + states.clear() + } } override def serializeState(): Array[Byte] = { @@ -285,7 +289,8 @@ abstract class Optimizer extends Serializable { } } -trait MXKVStoreUpdater { +trait MXKVStoreUpdater extends + NativeResource { /** * user-defined updater for the kvstore * It's this updater's responsibility to delete recv and local @@ -294,9 +299,14 @@ trait MXKVStoreUpdater { * @param local the value stored on local on this key */ def update(key: Int, recv: NDArray, local: NDArray): Unit - def dispose(): Unit - // def serializeState(): Array[Byte] - // def deserializeState(bytes: Array[Byte]): Unit + + // This is a hack to make Optimizers work with ResourceScope + // otherwise the user has to manage calling dispose on this object. + override def nativeAddress: CPtrAddress = hashCode() + override def nativeDeAllocator: CPtrAddress => Int = doNothingDeAllocator + private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0 + override val ref: NativeResourceRef = super.register() + override val bytesAllocated: Long = 0L } trait MXKVStoreCachedStates { diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala index 64ec772392af..1c5782d873a9 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala @@ -40,16 +40,16 @@ class ResourceScope extends AutoCloseable { } } - ResourceScope.addToLocalScope(this) + ResourceScope.addToThreadLocal(this) /** * Releases all the [[NativeResource]] by calling * the associated [[NativeResource.close()]] method */ override def close(): Unit = { + ResourceScope.removeFromThreadLocal(this) resourceQ.foreach(resource => if (resource != null) resource.dispose(false) ) resourceQ.clear() - ResourceScope.removeFromLocalScope(this) } /** @@ -114,7 +114,7 @@ object ResourceScope { curScope.remove(r) if (prevScope.isDefined) { prevScope.get.add(r) - r.scope = prevScope.get + r.scope = prevScope } } @@ -165,7 +165,7 @@ object ResourceScope { * Add resource to current ThreadLocal DataStructure * @param r ResourceScope to add. */ - private[mxnet] def addToLocalScope(r: ResourceScope): Unit = { + private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = { threadLocalScopes.get() += r } @@ -173,7 +173,7 @@ object ResourceScope { * Remove resource from current ThreadLocal DataStructure * @param r ResourceScope to remove */ - private[mxnet] def removeFromLocalScope(r: ResourceScope): Unit = { + private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = { threadLocalScopes.get() -= r } @@ -181,8 +181,8 @@ object ResourceScope { * Get the latest Scope in the stack * @return */ - private[mxnet] def getCurrentScope(): ResourceScope = { - Try(threadLocalScopes.get().last).getOrElse(null) + private[mxnet] def getCurrentScope(): Option[ResourceScope] = { + Try(Some(threadLocalScopes.get().last)).getOrElse(None) } /** @@ -193,4 +193,4 @@ object ResourceScope { val scopes = threadLocalScopes.get() Try(Some(scopes(scopes.size - 2))).getOrElse(None) } -} \ No newline at end of file +} diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala index e20b433ed1ed..d349feac3e93 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala @@ -17,7 +17,7 @@ package org.apache.mxnet.optimizer -import org.apache.mxnet.{Optimizer, LRScheduler, NDArray} +import org.apache.mxnet._ import org.apache.mxnet.NDArrayConversions._ /** @@ -92,7 +92,13 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f, if (momentum == 0.0f) { null } else { - NDArray.zeros(weight.shape, weight.context) + val s = NDArray.zeros(weight.shape, weight.context) + // this is created on the fly and shared between runs, + // we don't want it to be dispose from the scope + // and should be handled by the dispose + val scope = ResourceScope.getCurrentScope() + if (scope.isDefined) scope.get.remove(s) + s } } From f0e873b07222b572125223b1770541fe403854e0 Mon Sep 17 00:00:00 2001 From: Naveen Swamy Date: Thu, 18 Oct 2018 13:15:09 -0700 Subject: [PATCH 21/21] Add comments for dispose(param:Boolean) --- .../org/apache/mxnet/NativeResource.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala index e3e7c7a3aa70..48d4b0c193b1 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala @@ -78,12 +78,6 @@ private[mxnet] trait NativeResource NativeResourceRef.register(this, nativeDeAllocator) } - // Removes this object from PhantomRef tracking and from ResourceScope - private def deRegister(removeFromScope: Boolean): Unit = { - NativeResourceRef.deRegister(ref) - if (scope.isDefined && removeFromScope) scope.get.remove(this) - } - // Implements [[@link AutoCloseable.close]] override def close(): Unit = { dispose() @@ -92,10 +86,22 @@ private[mxnet] trait NativeResource // Implements [[@link WarnIfNotDisposed.dispose]] def dispose(): Unit = dispose(true) + /** + * This method deAllocates nativeResource and deRegisters + * from PhantomRef and removes from Scope if + * removeFromScope is set to true. + * @param removeFromScope remove from the currentScope if true + */ + // the parameter here controls whether to remove from current scope. + // [[ResourceScope.close]] calls NativeResource.dispose + // if we remove from the ResourceScope ie., from the container in ResourceScope. + // while iterating on the container, calling iterator.next is undefined and not safe. + // Note that ResourceScope automatically disposes all the resources within. private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = { if (!disposed) { checkCall(nativeDeAllocator(this.nativeAddress)) - deRegister(removeFromScope) + NativeResourceRef.deRegister(ref) // removes from PhantomRef tracking + if (removeFromScope && scope.isDefined) scope.get.remove(this) NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated) disposed = true } @@ -138,11 +144,8 @@ private[mxnet] object NativeResourceRef { ref } - def deRegister(ref: NativeResourceRef): Unit = { - if (refMap.containsKey(ref)) { - refMap.remove(ref) - } - } + // remove from PhantomRef tracking + def deRegister(ref: NativeResourceRef): Unit = refMap.remove(ref) /** * This method will check if the cleaner ran and deAllocated the object