diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala index f51424b7edf6..b0fae0f9d58d 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala @@ -61,6 +61,8 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, protected var monitorCallback: MXMonitorCallback = null private val logger: Logger = LoggerFactory.getLogger(classOf[Executor]) + private var reshaped = false + override def nativeAddress: CPtrAddress = handle override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree // cannot determine the off-heap size of this object @@ -71,12 +73,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, if (!super.isDisposed) { super.dispose() outputs.foreach(o => o.dispose()) - if (argArrays != null) {argArrays.foreach(a => a.dispose())} - if (gradArrays != null) {gradArrays.foreach( + if (reshaped && argArrays != null) {argArrays.foreach(a => a.dispose())} + if (reshaped && gradArrays != null) {gradArrays.foreach( // Symbol will sometimes fill this with nulls so we've got to check the elements too a => if (a != null) {a.dispose()}) } - if (auxArrays != null) {auxArrays.foreach(a => a.dispose())} + if (reshaped && auxArrays != null) {auxArrays.foreach(a => a.dispose())} } } @@ -146,6 +148,7 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle, executor.argArrays = argArrays executor.gradArrays = gradArrays executor.auxArrays = auxArrays + executor.reshaped = true executor } diff --git a/src/common/utils.h b/src/common/utils.h index 6cdb869ff9ae..251a8fe3c190 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -654,7 +654,7 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name, } else if (ctx.dev_mask() == gpu::kDevMask) { return fcompute_gpu.get(op, nullptr); } else { - LOG(FATAL) << "Unknown device mask"; + LOG(FATAL) << "Unknown device mask " << ctx.dev_mask(); return nullptr; } }