64 changes: 26 additions & 38 deletions core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -20,7 +20,8 @@ package org.apache.spark
import java.io.{ObjectInputStream, Serializable}

import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.collection.Map
import scala.collection.mutable
import scala.ref.WeakReference
import scala.reflect.ClassTag

@@ -42,22 +43,37 @@ import org.apache.spark.util.Utils
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
class Accumulable[R, T] private[spark] (
@transient initialValue: R,
param: AccumulableParam[R, T],
val name: Option[String])
val name: Option[String],
internal: Boolean)
Review comment (Contributor): add internal to the scaladoc param to explain what it is.

extends Serializable {

private[spark] def this(
@transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
this(initialValue, param, None, internal)
}

def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
this(initialValue, param, name, false)

def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
this(initialValue, param, None)

val id: Long = Accumulators.newId

@transient private var value_ = initialValue // Current value on master
@volatile @transient private var value_ : R = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
private var deserialized = false

Accumulators.register(this, true)
Accumulators.register(this)

/**
* Internal accumulators will be reported via heartbeats. For internal accumulators, `R` must be
* thread safe so that they can be reported correctly.
*/
private[spark] def isInternal: Boolean = internal

/**
* Add more data to this accumulator / accumulable
@@ -132,7 +148,8 @@ class Accumulable[R, T] (
in.defaultReadObject()
value_ = zero
deserialized = true
Accumulators.register(this, false)
val taskContext = TaskContext.get()
taskContext.registerAccumulator(this)
}

override def toString: String = if (value_ == null) "null" else value_.toString
@@ -284,16 +301,7 @@ private[spark] object Accumulators extends Logging {
* It keeps weak references to these objects so that accumulators can be garbage-collected
* once the RDDs and user-code that reference them are cleaned up.
*/
val originals = Map[Long, WeakReference[Accumulable[_, _]]]()

/**
* This thread-local map holds per-task copies of accumulators; it is used to collect the set
* of accumulator updates to send back to the driver when tasks complete. After tasks complete,
* this map is cleared by `Accumulators.clear()` (see Executor.scala).
*/
private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
override protected def initialValue() = Map[Long, Accumulable[_, _]]()
}
val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()

private var lastId: Long = 0

@@ -302,19 +310,8 @@ private[spark] object Accumulators extends Logging {
lastId
}

def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
} else {
localAccums.get()(a.id) = a
}
}

// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
localAccums.get.clear()
}
def register(a: Accumulable[_, _]): Unit = synchronized {
originals(a.id) = new WeakReference[Accumulable[_, _]](a)
}

def remove(accId: Long) {
@@ -323,15 +320,6 @@
}
}

// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
}
return ret
}

// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
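As a reference for reviewers, here is a small self-contained sketch of the registration scheme this hunk moves to: the driver keeps only weakly-referenced "originals" in a plain registry, while deserialized per-task copies register with the running task's context instead of the removed thread-local map. All names below (ToyAccumulator, ToyRegistry) are illustrative stand-ins, not Spark's classes.

```scala
import scala.collection.mutable
import scala.ref.WeakReference

final class ToyAccumulator(val id: Long, val isInternal: Boolean) {
  @volatile var value: Long = 0L
}

object ToyRegistry {
  // Driver side: originals only, held weakly so GC can reclaim accumulators
  // once the user code that referenced them is gone.
  private val originals = mutable.Map[Long, WeakReference[ToyAccumulator]]()
  private var lastId = 0L

  def newId(): Long = synchronized { lastId += 1; lastId }

  def register(a: ToyAccumulator): Unit = synchronized {
    originals(a.id) = new WeakReference(a)
  }

  def remove(id: Long): Unit = synchronized { originals -= id }

  // Driver side: fold task-reported updates back into the originals that are still alive.
  def add(updates: Map[Long, Long]): Unit = synchronized {
    for ((id, delta) <- updates; ref <- originals.get(id); acc <- ref.get) {
      acc.value += delta
    }
  }
}
```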
6 changes: 6 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -152,4 +152,10 @@ abstract class TaskContext extends Serializable {
* Returns the manager for this task's managed memory.
*/
private[spark] def taskMemoryManager(): TaskMemoryManager

private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit

private[spark] def collectInternalAccumulators(): Map[Long, Any]

private[spark] def collectAccumulators(): Map[Long, Any]
Review comment (Contributor): can you add scaladoc for these methods? In particular, for this one, does it contain all accumulators, including internal ones?

Review comment (Contributor): also make sure you document what the semantics are for the key (Long) and value (Any).

}
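Regarding the scaladoc requested above, one possible wording (a suggestion only, not the author's text; per TaskContextImpl below, the key is the accumulator id, the value is that accumulator's per-task local value, and collectAccumulators does include internal accumulators):

```scala
/**
 * Registers the given per-task accumulator copy so that its value is included in the
 * updates this task reports back to the driver.
 */
private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit

/**
 * Returns the current local values of all registered accumulators that are marked internal,
 * keyed by accumulator id. Internal accumulators are the ones reported via heartbeats.
 */
private[spark] def collectInternalAccumulators(): Map[Long, Any]

/**
 * Returns the current local values of all accumulators registered with this task,
 * including internal ones, keyed by accumulator id.
 */
private[spark] def collectAccumulators(): Map[Long, Any]
```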
19 changes: 16 additions & 3 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -17,12 +17,12 @@

package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.executor.TaskMetrics
import org.apache.spark.unsafe.memory.TaskMemoryManager
import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}

import scala.collection.mutable.ArrayBuffer

private[spark] class TaskContextImpl(
val stageId: Int,
val partitionId: Int,
@@ -94,5 +94,18 @@ private[spark] class TaskContextImpl(
override def isRunningLocally(): Boolean = runningLocally

override def isInterrupted(): Boolean = interrupted
}

@transient private val accumulators = new HashMap[Long, Accumulable[_, _]]

private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
accumulators(a.id) = a
}

private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
}

private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
accumulators.mapValues(_.localValue).toMap
}
}
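A standalone sketch of the collection pattern used here, with a stand-in Acc class instead of Spark's Accumulable: registered per-task accumulators live in a map, and a snapshot either filters for the internal ones or takes everything.

```scala
import scala.collection.mutable

// Stand-in for Accumulable: just an id, an internal flag, and a local value.
final class Acc(val id: Long, val isInternal: Boolean, var localValue: Any)

object CollectDemo {
  private val accumulators = new mutable.HashMap[Long, Acc]

  def register(a: Acc): Unit = synchronized { accumulators(a.id) = a }

  // Only accumulators marked internal, as reported via heartbeats.
  def collectInternal(): Map[Long, Any] = synchronized {
    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
  }

  // Every accumulator registered with this task, internal or not.
  def collectAll(): Map[Long, Any] = synchronized {
    accumulators.mapValues(_.localValue).toMap
  }

  def main(args: Array[String]): Unit = {
    register(new Acc(1L, isInternal = true, localValue = 10))
    register(new Acc(2L, isInternal = false, localValue = "x"))
    println(collectInternal()) // Map(1 -> 10)
    println(collectAll())      // both entries; HashMap iteration order is unspecified
  }
}
```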
6 changes: 2 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -209,7 +209,7 @@ private[spark] class Executor(

// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = try {
val (value, accumUpdates) = try {
task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
} finally {
// Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
@@ -247,7 +247,6 @@
m.setResultSerializationTime(afterSerialization - beforeSerialization)
}

val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
val serializedDirectResult = ser.serialize(directResult)
val resultSize = serializedDirectResult.limit
@@ -314,8 +313,6 @@
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
// Release memory used by this thread for accumulators
Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -424,6 +421,7 @@
metrics.updateShuffleReadMetrics()
metrics.updateInputMetrics()
metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
metrics.updateAccumulators()

if (isLocal) {
// JobProgressListener will hold an reference of it during
13 changes: 13 additions & 0 deletions core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -223,6 +223,19 @@ class TaskMetrics extends Serializable {
// overhead.
_hostname = TaskMetrics.getCachedHostName(_hostname)
}

private var _accumulatorUpdates: Map[Long, Any] = Map.empty
@transient private var _accumulatorsUpdater: () => Map[Long, Any] = null

private[spark] def updateAccumulators(): Unit = synchronized {
_accumulatorUpdates = _accumulatorsUpdater()
}

def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates

private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
_accumulatorsUpdater = accumulatorsUpdater
}
}

private[spark] object TaskMetrics {
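For context, a simplified standalone model of the updater hook added here: TaskMetrics holds a thunk supplied by the task, and updateAccumulators() just snapshots whatever that thunk currently returns. ToyMetrics is illustrative, not Spark's class, and it adds a null guard that the diff itself does not.

```scala
class ToyMetrics extends Serializable {
  // Last snapshot of internal accumulator values, keyed by accumulator id.
  private var _accumulatorUpdates: Map[Long, Any] = Map.empty

  // Transient: the thunk closes over the live task context and is only needed executor-side.
  @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null

  def setAccumulatorsUpdater(updater: () => Map[Long, Any]): Unit = {
    _accumulatorsUpdater = updater
  }

  // Called before reporting (e.g. on a heartbeat) to refresh the snapshot.
  def updateAccumulators(): Unit = synchronized {
    if (_accumulatorsUpdater != null) {
      _accumulatorUpdates = _accumulatorsUpdater()
    }
  }

  def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
}

// Wiring, mirroring Task.run in this PR:
//   metrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
```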
@@ -22,7 +22,8 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
import scala.collection.Map
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
import scala.concurrent.duration._
import scala.language.existentials
import scala.language.postfixOps

@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import java.util.Properties

import scala.collection.mutable.Map
import scala.collection.Map
import scala.language.existentials

import org.apache.spark._
7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,14 +45,16 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {

type AccumulatorUpdates = Map[Long, Any]

/**
* Called by [[Executor]] to run this task.
*
* @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
* @param attemptNumber how many times this task has been attempted (0 for the first attempt)
* @return the result of the task
Review comment (Contributor): update the documentation here.

*/
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = {
context = new TaskContextImpl(
stageId = stageId,
partitionId = partitionId,
@@ -62,12 +64,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
runningLocally = false)
TaskContext.setTaskContext(context)
context.taskMetrics.setHostname(Utils.localHostName())
context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
}
try {
runTask(context)
(runTask(context), context.collectAccumulators())
} finally {
context.markTaskCompleted()
TaskContext.unset()
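On the review comment above asking to update the documentation: one possible wording for the @return line, reflecting the new pair return type (a suggestion, not the author's text):

```scala
/**
 * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
 * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
 * @return a pair of (task result, accumulator updates), where the updates map each accumulator
 *         id registered with this task's context to that accumulator's final local value
 */
```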
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.Map
import scala.collection.Map
import scala.collection.mutable

import org.apache.spark.SparkEnv
import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
if (numUpdates == 0) {
accumUpdates = null
} else {
accumUpdates = Map()
val _accumUpdates = mutable.Map[Long, Any]()
for (i <- 0 until numUpdates) {
accumUpdates(in.readLong()) = in.readObject()
_accumUpdates(in.readLong()) = in.readObject()
}
accumUpdates = _accumUpdates
}
metrics = in.readObject().asInstanceOf[TaskMetrics]
valueObjectDeserialized = false
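A minimal standalone sketch of the read-side pattern in this hunk: because accumUpdates is now typed as the read-only scala.collection.Map, entries are accumulated in a local mutable.Map and assigned once at the end. The readLong/readAny function parameters stand in for the ObjectInput stream.

```scala
import scala.collection.Map
import scala.collection.mutable

object ReadUpdates {
  def readUpdates(numUpdates: Int, readLong: () => Long, readAny: () => Any): Map[Long, Any] = {
    if (numUpdates == 0) {
      null // mirrors the diff: nothing was written, so the field stays null
    } else {
      // Build into a mutable map, then hand it back behind the read-only interface.
      val _accumUpdates = mutable.Map[Long, Any]()
      for (_ <- 0 until numUpdates) {
        _accumUpdates(readLong()) = readAny()
      }
      _accumUpdates
    }
  }
}
```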
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler

import java.util.Random

import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.util.ManualClock

class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
extends DAGScheduler(sc) {
@@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: mutable.Map[Long, Any],
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
taskScheduler.endedTasks(taskInfo.index) = reason