[Scala] fix memory leak for training, mnist example tested #5

Status: Open · wants to merge 1 commit into base branch scala-nd-collector
@@ -234,7 +234,8 @@ private[mxnet] object ExecutorManager {
 
   // Load a list of arrays into a list of arrays specified by slices
   private[mxnet] def loadGeneralMulti(data: Seq[NDArray],
-                                      targets: Seq[Array[(Int, Int, NDArray)]]): Unit = {
+                                      targets: Seq[Array[(Int, Int, NDArray)]])
+    : Unit = NDArrayCollector.auto().withScope {
     for ((src, dTargets) <- data zip targets) {
       for ((start, end, dst) <- dTargets) {
         val sliced = src.slice(start, end)
@@ -522,7 +523,8 @@ private class DataParallelExecutorGroup private(sym: Symbol,
   }
 
   // Update evaluation metric with label and current outputs
-  def updateMetric(metric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
+  def updateMetric(metric: EvalMetric, labels: IndexedSeq[NDArray])
+    : Unit = NDArrayCollector.auto().withScope {
     (trainExecs zip slices).foreach { case (texec, islice) =>
       val labelsSlice = labels.map(_.slice(islice))
       metric.update(labelsSlice, texec.outputs)
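Note: both methods now run their bodies inside an NDArrayCollector scope instead of freeing the per-call slices by hand. A minimal sketch of that pattern, assuming the NDArrayCollector semantics introduced on the scala-nd-collector base branch (the helper `scaleAll` and its arguments are illustrative, not part of the PR):

```scala
import org.apache.mxnet.{NDArray, NDArrayCollector}

// Every NDArray allocated inside withScope is registered with the collector
// and disposed when the block exits, so temporaries such as slices no longer
// have to be freed one by one.
def scaleAll(data: Seq[NDArray], factor: Float): Unit =
  NDArrayCollector.auto().withScope {
    data.foreach { src =>
      val scaled = src * factor // temporary, tracked by the collector
      src.set(scaled)           // write the result back into the caller's array
    }
    // On exit the collector disposes `scaled` and any other temporaries;
    // the arrays in `data` were created outside the scope and are untouched.
  }
```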
@@ -160,6 +160,13 @@ class FeedForward private(
       }
     }
 
+    if (_argParams != null) {
+      _argParams.values.foreach(_.dispose())
+    }
+    if (_auxParams != null) {
+      _auxParams.values.foreach(_.dispose())
+    }
+
     _argParams = argParams
     _auxParams = auxParams
     (argNames, paramNames, auxNames)
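Note: the new guard releases the previously held parameter arrays before the fields are reassigned. A short illustration of the dispose-before-overwrite idiom (the `ParamHolder` class is made up for this sketch; only the idiom matches the diff): an NDArray owns native memory that the JVM garbage collector does not reclaim promptly, so dropping the last reference without calling dispose() leaks it.

```scala
import org.apache.mxnet.NDArray

// Hypothetical holder demonstrating the dispose-before-overwrite idiom.
class ParamHolder {
  private var params: Map[String, NDArray] = Map.empty

  def setParams(newParams: Map[String, NDArray]): Unit = {
    // Free the old native buffers first; after reassignment the old
    // NDArrays would be unreachable but their native memory would remain.
    params.values.foreach(_.dispose())
    params = newParams
  }
}
```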
@@ -319,6 +319,8 @@ object Model {
         if (epochSize != -1 && nBatch >= epochSize) {
           doReset = false
         }
+
+        dataBatch.dispose()
       }
       if (doReset) {
         trainData.reset()
@@ -344,6 +346,7 @@
           executorManager.loadDataBatch(evalBatch)
           executorManager.forward(isTrain = false)
           executorManager.updateMetric(evalMetric, evalBatch.label)
+          evalBatch.dispose()
         }
 
         val (name, value) = evalMetric.get
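Note: the training and evaluation loops now free each batch once the executors have consumed it. A minimal sketch of that ownership rule (`consume` and `processBatch` are stand-ins for the surrounding Model code, not names from the PR):

```scala
import org.apache.mxnet.{DataBatch, DataIter}

// The loop that pulls batches from the iterator owns their NDArrays,
// so it disposes them after the forward/metric work is done; otherwise
// every step of every epoch leaks one batch worth of native memory.
def consume(iter: DataIter)(processBatch: DataBatch => Unit): Unit = {
  while (iter.hasNext) {
    val batch = iter.next()
    processBatch(batch)
    batch.dispose()
  }
}
```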
@@ -45,7 +45,6 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
                                 _provideLabel: ListMap[String, Shape],
                                 _batchSize: Int) =
     if (hasNext) {
-      iterNext()
       val data = currentBatch.data(0)
       val label = currentBatch.label(0)
       // properties
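Note: this reads as a double-advance fix. Assuming `hasNext` itself calls `iterNext()` to populate `currentBatch` (the diff does not show that method), the second call fetched another batch on top of the first without disposing it, leaking the first batch's NDArrays as well as skipping it; removing the redundant call leaves the batch produced by `hasNext` in place.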
@@ -17,7 +17,7 @@
 
 package org.apache.mxnet.optimizer
 
-import org.apache.mxnet.{Optimizer, LRScheduler, NDArray}
+import org.apache.mxnet.{LRScheduler, NDArray, NDArrayCollector, Optimizer}
 import org.apache.mxnet.NDArrayConversions._
 
 /**
@@ -39,7 +39,8 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
    * @param state NDArray or other objects returned by initState
    *              The auxiliary state used in optimization.
    */
-  override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef): Unit = {
+  override def update(index: Int, weight: NDArray, grad: NDArray, state: AnyRef)
+    : Unit = NDArrayCollector.auto().withScope {
     // TODO(bing) implement wd_bias, wd_gamma, wd_beta (copy from python package)
     var lr = {
       if (lrScheduler != null) {
@@ -58,7 +59,6 @@
       // to get rid of memory leak
       val oldResdGrad = resdGrad
       resdGrad = NDArray.clip(resdGrad, -clipGradient, clipGradient)
-      oldResdGrad.dispose()
     }
 
     if (state != null) {
@@ -71,7 +71,6 @@
       adder *= (-lr)
       mom += adder
       weight += mom
-      adder.dispose()
     } else {
       require(momentum == 0f)
       // adder = -lr * (resdGrad + this.wd * weight)
@@ -80,10 +79,7 @@
       adder += resdGrad
       adder *= (-lr)
       weight += adder
-      adder.dispose()
     }
-
-    resdGrad.dispose()
   }
 
   // Create additional optimizer state such as momentum.
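Note: with update() running inside NDArrayCollector.auto().withScope, the hand-written dispose() calls for oldResdGrad, adder and resdGrad become unnecessary: every temporary created in the body is reclaimed when the method returns. A condensed sketch of the momentum-free branch under that assumption (`sgdStep` is illustrative, not the PR's code):

```scala
import org.apache.mxnet.{NDArray, NDArrayCollector}

// Simplified SGD step: temporaries created inside the scope are disposed
// automatically on exit, while the caller-owned `weight` (created outside
// the scope) is updated in place and kept alive.
def sgdStep(weight: NDArray, grad: NDArray, lr: Float, wd: Float): Unit =
  NDArrayCollector.auto().withScope {
    // adder = -lr * (grad + wd * weight)
    val adder = weight * wd
    adder += grad
    adder *= (-lr)
    weight += adder
    // No adder.dispose() needed here: the collector frees it on scope exit.
  }
```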