36 changes: 28 additions & 8 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -23,6 +23,7 @@ import java.lang.ThreadLocal

import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.ref.WeakReference
import scala.reflect.ClassTag

import org.apache.spark.serializer.JavaSerializer
@@ -280,10 +281,12 @@ object AccumulatorParam {
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private[spark] object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulable[_, _]]()
val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
// Store a WeakReference instead of a StrongReference because this way accumulators can be
// appropriately garbage collected during long-running jobs and release memory
type WeakAcc = WeakReference[Accumulable[_, _]]
val originals = Map[Long, WeakAcc]()
val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() {
Contributor: Guava MapMaker supports weakValues; not sure if we want to use that here, since it's not super Scala-friendly (e.g. returns nulls, etc): http://docs.guava-libraries.googlecode.com/git-history/release/javadoc/com/google/common/collect/MapMaker.html

Author: Hi Josh - are you suggesting to replace this snippet with a MapMaker just to simplify the initialization code? I believe the usage of either object would be the same - do you see a specific advantage to trying to use the MapMaker?

Contributor: Let's just leave this as-is; I don't think MapMaker will buy us much now that I think about it.
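For context, a minimal sketch of the MapMaker alternative discussed above — an illustration only, not part of this patch, assuming Guava's MapMaker is available on the classpath:

import java.util.concurrent.ConcurrentMap
import com.google.common.collect.MapMaker

// weakValues() lets the map drop an entry once its accumulator is otherwise unreachable,
// but reads then return null for collected entries, which must be handled on the Scala side.
val originals: ConcurrentMap[java.lang.Long, Accumulable[_, _]] =
  new MapMaker().weakValues().makeMap[java.lang.Long, Accumulable[_, _]]()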

override protected def initialValue() = Map[Long, WeakAcc]()
}
var lastId: Long = 0

@@ -294,9 +297,9 @@ private[spark] object Accumulators {

def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = a
originals(a.id) = new WeakAcc(a)
} else {
localAccums.get()(a.id) = a
localAccums.get()(a.id) = new WeakAcc(a)
}
}

@@ -307,11 +310,22 @@
}
}

def remove(accId : Long) {
synchronized {
originals.remove(accId)
}
}

// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
// Since we are now storing weak references, we must check whether the underlying data
// is valid.
ret(id) = accum.get match {
case Some(values) => values.localValue
case None => None
Contributor: This was dumb of me to overlook, too: if the data structure is invalid, then this would just silently ignore it; even if this was a rare error condition, there should have been a warning here.

}
}
return ret
}
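As the comment above notes, a collected reference is silently ignored in values; a minimal sketch of the suggested warning, assuming the enclosing object mixed in org.apache.spark.Logging (it does not in this patch):

ret(id) = accum.get match {
  case Some(a) => a.localValue
  case None =>
    // The accumulator behind this weak reference has already been garbage collected.
    logWarning(s"Local accumulator $id has been garbage collected; reporting no value for it")
    None
}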
@@ -320,7 +334,13 @@ private[spark] object Accumulators {
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (originals.contains(id)) {
originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
// Since we are now storing weak references, we must check whether the underlying data
// is valid.
originals(id).get match {
case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value
case None =>
throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
Contributor: If the Accumulator is garbage collected, should we log and continue?

Contributor: The exception thrown here is caught at higher levels of the stack. For example, DAGScheduler wraps calls to accumulator methods in a try block and logs any uncaught exceptions. Have you run into a case where the current behavior causes a problem?

}
}
}
}
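A minimal sketch of the log-and-continue alternative raised in the thread above, again assuming Accumulators mixed in org.apache.spark.Logging; the patch instead keeps the exception and relies on DAGScheduler to catch and log it:

originals(id).get match {
  case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value
  case None =>
    // Drop the update rather than failing when the accumulator has been garbage collected.
    logWarning(s"Ignoring update for accumulator $id because it has been garbage collected")
}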
20 changes: 20 additions & 0 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -32,6 +32,7 @@ private sealed trait CleanupTask
private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
private case class CleanAccum(accId: Long) extends CleanupTask

/**
* A WeakReference associated with a CleanupTask.
@@ -114,6 +115,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanRDD(rdd.id))
}

def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
registerForCleanup(a, CleanAccum(a.id))
}

/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
@@ -145,6 +150,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
case CleanBroadcast(broadcastId) =>
doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
case CleanAccum(accId) =>
doCleanupAccum(accId, blocking = blockOnCleanupTasks)
}
}
} catch {
@@ -190,6 +197,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}

/** Perform accumulator cleanup. */
def doCleanupAccum(accId: Long, blocking: Boolean) {
try {
logDebug("Cleaning accumulator " + accId)
Accumulators.remove(accId)
listeners.foreach(_.accumCleaned(accId))
logInfo("Cleaned accumulator " + accId)
} catch {
case e: Exception => logError("Error cleaning accumulator " + accId, e)
}
}

private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
@@ -206,4 +225,5 @@ private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
def broadcastCleaned(broadcastId: Long)
def accumCleaned(accId: Long)
}
29 changes: 22 additions & 7 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -986,15 +986,21 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
{
val acc = new Accumulator(initialValue, param)
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}
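For illustration, a hypothetical driver-side usage of this factory method, assuming a live SparkContext named sc; the only behavioral change in the patch is that the returned accumulator is now also registered with the ContextCleaner:

val errorCount = sc.accumulator(0)                  // implicitly uses AccumulatorParam[Int]
sc.parallelize(1 to 1000).foreach { i =>
  if (i % 100 == 0) errorCount += 1                 // tasks may only add to the accumulator
}
println(errorCount.value)                           // only the driver may read the value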

/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
* in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the
* driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = {
new Accumulator(initialValue, param, Some(name))
val acc = new Accumulator(initialValue, param, Some(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}

/**
@@ -1003,9 +1009,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* @tparam R accumulator result type
* @tparam T type that can be added to the accumulator
*/
def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param)

def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = {
val acc = new Accumulable(initialValue, param)
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}

/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
* Spark UI. Tasks can add values to the accumulable using the `+=` operator. Only the driver can
@@ -1014,7 +1023,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* @tparam T type that can be added to the accumulator
*/
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) =
new Accumulable(initialValue, param, Some(name))
{
val acc = new Accumulable(initialValue, param, Some(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}

/**
* Create an accumulator from a "mutable collection" type.
@@ -1025,7 +1038,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
val acc = new Accumulable(initialValue, param)
cleaner.foreach(_.registerAccumulatorForCleanup(acc))
acc
}

/**
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -388,7 +388,7 @@ abstract class RDD[T: ClassTag](
new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
}
}

/**
* Randomly splits this RDD with the provided weights.
*
@@ -890,8 +890,16 @@ class DAGScheduler(
if (event.accumUpdates != null) {
try {
Accumulators.add(event.accumUpdates)

event.accumUpdates.foreach { case (id, partialValue) =>
val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
// In this instance, although the reference in Accumulators.originals is a WeakRef,
// it's guaranteed to exist since the event.accumUpdates Map exists

val acc = Accumulators.originals(id).get match {
case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
case None => throw new NullPointerException("Non-existent reference to Accumulator")
}

// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
21 changes: 21 additions & 0 deletions core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -22,6 +22,8 @@ import scala.collection.mutable
import org.scalatest.FunSuite
import org.scalatest.Matchers

import scala.ref.WeakReference


class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext {

@@ -135,5 +137,24 @@ class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext {
resetSparkContext()
}
}

test ("garbage collection") {
// Create an accumulator and let it go out of scope to test that it's properly garbage collected
sc = new SparkContext("local", "test")
var acc: Accumulable[mutable.Set[Any], Any] = sc.accumulable(new mutable.HashSet[Any]())
val accId = acc.id
val ref = WeakReference(acc)

// Ensure the accumulator is present
assert(ref.get.isDefined)

// Remove the explicit reference to it and allow weak reference to get garbage collected
acc = null
System.gc()
assert(ref.get.isEmpty)

Accumulators.remove(accId)
assert(!Accumulators.originals.get(accId).isDefined)
}

}
@@ -382,6 +382,10 @@ class CleanerTester(
toBeCleanedBroadcstIds -= broadcastId
logInfo("Broadcast" + broadcastId + " cleaned")
}

def accumCleaned(accId: Long) : Unit = {
logInfo("Cleaned accId " + accId + " cleaned")
}
}

val MAX_VALIDATION_ATTEMPTS = 10
@@ -735,7 +735,11 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
completeWithAccumulator(accum.id, taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assert(Accumulators.originals(accum.id).value === 1)

val accVal = Accumulators.originals(accum.id).get.get.value

assert(accVal === 1)

assertDataStructuresEmpty
}
