Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,24 @@ object AccumulatorParam {

// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private[spark] object Accumulators {
// Store a WeakReference instead of a StrongReference because this way accumulators can be
// appropriately garbage collected during long-running jobs and release memory
type WeakAcc = WeakReference[Accumulable[_, _]]
val originals = Map[Long, WeakAcc]()
val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() {
override protected def initialValue() = Map[Long, WeakAcc]()
private[spark] object Accumulators extends Logging {
/**
* This global map holds the original accumulator objects that are created on the driver.
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
*/
val originals = Map[Long, WeakReference[Accumulable[_, _]]]()

/**
* This thread-local map holds per-task copies of accumulators; it is used to collect the set
* of accumulator updates to send back to the driver when tasks complete. After tasks complete,
* this map is cleared by `Accumulators.clear()` (see Executor.scala).
*/
private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
}
var lastId: Long = 0

private var lastId: Long = 0

def newId(): Long = synchronized {
lastId += 1
Expand All @@ -297,16 +306,16 @@ private[spark] object Accumulators {

def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = new WeakAcc(a)
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
} else {
localAccums.get()(a.id) = new WeakAcc(a)
localAccums.get()(a.id) = a
}
}

// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
localAccums.get.clear
localAccums.get.clear()
}
}

Expand All @@ -320,12 +329,7 @@ private[spark] object Accumulators {
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
// Since we are now storing weak references, we must check whether the underlying data
// is valid.
ret(id) = accum.get match {
case Some(values) => values.localValue
case None => None
}
ret(id) = accum.localValue
}
return ret
}
Expand All @@ -341,6 +345,8 @@ private[spark] object Accumulators {
case None =>
throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
}
} else {
logWarning(s"Ignoring accumulator update for unknown accumulator id $id")
}
}
}
Expand Down