This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 812b63b

zsxwing authored and rxin committed
[SPARK-8857][SPARK-8859][Core]Add an internal flag to Accumulable and send internal accumulator updates to the driver via heartbeats
This PR includes the following changes:

1. Remove the thread-local `Accumulators.localAccums`. Instead, all Accumulators in the executors will register with their TaskContext.
2. Add an internal flag to Accumulable. For internal Accumulators, their updates will be sent to the driver via heartbeats.

Author: zsxwing <[email protected]>

Closes apache#7448 from zsxwing/accumulators and squashes the following commits:

c24bc5b [zsxwing] Add comments
bd7dcf1 [zsxwing] Add an internal flag to Accumulable and send internal accumulator updates to the driver via heartbeats
1 parent 96aa334 commit 812b63b
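
In short: instead of collecting accumulator updates from a thread-local map on the executor, each task now owns a registry inside its TaskContext, and accumulators flagged as internal have their values piggy-backed on executor heartbeats. The self-contained sketch below shows the shape of that design; the class and member names (`SketchAccumulable`, `SketchTaskContext`) are illustrative, not Spark's.

```scala
import scala.collection.mutable

// Simplified stand-ins for Accumulable and TaskContext, for illustration only.
class SketchAccumulable[T](val id: Long, zero: T, add: (T, T) => T, val isInternal: Boolean) {
  // The task thread merges updates while the heartbeat thread reads localValue,
  // so for internal accumulators the value must be safe to read concurrently.
  @volatile private var value: T = zero
  def +=(term: T): Unit = synchronized { value = add(value, term) }
  def localValue: T = value
}

class SketchTaskContext {
  private val accumulators = mutable.HashMap.empty[Long, SketchAccumulable[_]]

  // Called when an accumulator is deserialized as part of the task closure.
  def registerAccumulator(a: SketchAccumulable[_]): Unit = synchronized {
    accumulators(a.id) = a
  }

  // Reported to the driver with every heartbeat while the task runs.
  def collectInternalAccumulators(): Map[Long, Any] = synchronized {
    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
  }

  // Reported to the driver once, together with the task result.
  def collectAccumulators(): Map[Long, Any] = synchronized {
    accumulators.mapValues(_.localValue).toMap
  }
}

object SketchDemo extends App {
  val ctx = new SketchTaskContext
  val bytesRead = new SketchAccumulable[Long](id = 1L, zero = 0L, add = _ + _, isInternal = true)
  ctx.registerAccumulator(bytesRead)
  bytesRead += 1024L
  println(ctx.collectInternalAccumulators()) // Map(1 -> 1024)
}
```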

File tree

10 files changed (+104 / -56 lines changed)


core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 30 additions & 38 deletions
@@ -20,7 +20,8 @@ package org.apache.spark
 import java.io.{ObjectInputStream, Serializable}
 
 import scala.collection.generic.Growable
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
 import scala.ref.WeakReference
 import scala.reflect.ClassTag
 
@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `R` and `T`
  * @param name human-readable name for use in Spark's web UI
+ * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
+ *                 to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
+ *                 thread safe so that they can be reported correctly.
  * @tparam R the full accumulated data (result type)
  * @tparam T partial data that can be added in
  */
-class Accumulable[R, T] (
+class Accumulable[R, T] private[spark] (
     @transient initialValue: R,
     param: AccumulableParam[R, T],
-    val name: Option[String])
+    val name: Option[String],
+    internal: Boolean)
   extends Serializable {
 
+  private[spark] def this(
+      @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
+    this(initialValue, param, None, internal)
+  }
+
+  def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
+    this(initialValue, param, name, false)
+
   def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
     this(initialValue, param, None)
 
   val id: Long = Accumulators.newId
 
-  @transient private var value_ = initialValue // Current value on master
+  @volatile @transient private var value_ : R = initialValue // Current value on master
   val zero = param.zero(initialValue) // Zero value to be passed to workers
   private var deserialized = false
 
-  Accumulators.register(this, true)
+  Accumulators.register(this)
+
+  /**
+   * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
+   * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be
+   * reported correctly.
+   */
+  private[spark] def isInternal: Boolean = internal
 
   /**
    * Add more data to this accumulator / accumulable
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
     in.defaultReadObject()
     value_ = zero
     deserialized = true
-    Accumulators.register(this, false)
+    val taskContext = TaskContext.get()
+    taskContext.registerAccumulator(this)
   }
 
   override def toString: String = if (value_ == null) "null" else value_.toString
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
    * It keeps weak references to these objects so that accumulators can be garbage-collected
    * once the RDDs and user-code that reference them are cleaned up.
    */
-  val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  /**
-   * This thread-local map holds per-task copies of accumulators; it is used to collect the set
-   * of accumulator updates to send back to the driver when tasks complete. After tasks complete,
-   * this map is cleared by `Accumulators.clear()` (see Executor.scala).
-   */
-  private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
-    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
-  }
+  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
 
   private var lastId: Long = 0
 
@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
     lastId
   }
 
-  def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
-    if (original) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    } else {
-      localAccums.get()(a.id) = a
-    }
-  }
-
-  // Clear the local (non-original) accumulators for the current thread
-  def clear() {
-    synchronized {
-      localAccums.get.clear()
-    }
+  def register(a: Accumulable[_, _]): Unit = synchronized {
+    originals(a.id) = new WeakReference[Accumulable[_, _]](a)
   }
 
   def remove(accId: Long) {
@@ -323,15 +324,6 @@ private[spark] object Accumulators extends Logging {
     }
   }
 
-  // Get the values of the local accumulators for the current thread (by ID)
-  def values: Map[Long, Any] = synchronized {
-    val ret = Map[Long, Any]()
-    for ((id, accum) <- localAccums.get) {
-      ret(id) = accum.localValue
-    }
-    return ret
-  }
-
   // Add values to the original accumulators with some given IDs
   def add(values: Map[Long, Any]): Unit = synchronized {
     for ((id, value) <- values) {
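
With the thread-local gone, the driver side keeps only the `originals` map, now explicitly a `mutable.Map` of weak references; updates arriving from executors are merged into whichever originals are still alive. A minimal, self-contained sketch of that merge path (names such as `CounterAcc` and `DriverRegistry` are illustrative, not Spark's):

```scala
import scala.collection.mutable
import scala.ref.WeakReference

// A tiny stand-in for an accumulator living on the driver.
class CounterAcc(val id: Long) {
  private var value = 0L
  def add(v: Long): Unit = synchronized { value += v }
  def current: Long = synchronized { value }
}

object DriverRegistry {
  // Weak references let an accumulator be garbage-collected once user code
  // drops it, even if the registry still holds an entry for its id.
  private val originals = mutable.Map.empty[Long, WeakReference[CounterAcc]]

  def register(a: CounterAcc): Unit = synchronized {
    originals(a.id) = new WeakReference(a)
  }

  // Merge (id -> partial value) updates reported by an executor, either with
  // the task result or, for internal accumulators, via a heartbeat.
  def add(updates: Map[Long, Long]): Unit = synchronized {
    for ((id, value) <- updates; ref <- originals.get(id); acc <- ref.get) {
      acc.add(value)
    }
  }
}

object DriverRegistryDemo extends App {
  val acc = new CounterAcc(id = 7L)
  DriverRegistry.register(acc)
  DriverRegistry.add(Map(7L -> 42L))
  println(acc.current) // 42
}
```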

core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 18 additions & 0 deletions
@@ -152,4 +152,22 @@ abstract class TaskContext extends Serializable {
    * Returns the manager for this task's managed memory.
    */
   private[spark] def taskMemoryManager(): TaskMemoryManager
+
+  /**
+   * Register an accumulator that belongs to this task. Accumulators must call this method when
+   * deserializing in executors.
+   */
+  private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit
+
+  /**
+   * Return the local values of internal accumulators that belong to this task. The key of the Map
+   * is the accumulator id and the value of the Map is the latest accumulator local value.
+   */
+  private[spark] def collectInternalAccumulators(): Map[Long, Any]
+
+  /**
+   * Return the local values of accumulators that belong to this task. The key of the Map is the
+   * accumulator id and the value of the Map is the latest accumulator local value.
+   */
+  private[spark] def collectAccumulators(): Map[Long, Any]
 }
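
The contract "Accumulators must call this method when deserializing in executors" relies on a standard Java-serialization hook: the accumulator's `readObject` looks up the current task's context and registers itself. A self-contained sketch of that mechanism, with illustrative names (`CurrentTaskContext`, `AccSketch`) and a thread-local holder assumed to play the role of `TaskContext.get()`:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable

// A per-thread "current task context" holder, as the executor would set up.
object CurrentTaskContext {
  private val holder = new ThreadLocal[RegistrySketch]
  def set(r: RegistrySketch): Unit = holder.set(r)
  def get(): RegistrySketch = holder.get()
}

class RegistrySketch {
  val registered = mutable.ArrayBuffer.empty[Long]
}

class AccSketch(val id: Long) extends Serializable {
  // Java serialization hook: runs on the executor when the task closure is deserialized.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    CurrentTaskContext.get().registered += id // register with this task's context
  }
}

object DeserializationDemo extends App {
  val bytes = { // serialize on the "driver"
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    out.writeObject(new AccSketch(3L)); out.close(); bos.toByteArray
  }
  CurrentTaskContext.set(new RegistrySketch) // the "executor" sets up the task context
  new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()
  println(CurrentTaskContext.get().registered) // ArrayBuffer(3)
}
```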

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 16 additions & 3 deletions
@@ -17,12 +17,12 @@
 
 package org.apache.spark
 
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.unsafe.memory.TaskMemoryManager
 import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException}
 
-import scala.collection.mutable.ArrayBuffer
-
 private[spark] class TaskContextImpl(
     val stageId: Int,
     val partitionId: Int,
@@ -94,5 +94,18 @@ private[spark] class TaskContextImpl(
   override def isRunningLocally(): Boolean = runningLocally
 
   override def isInterrupted(): Boolean = interrupted
-}
 
+  @transient private val accumulators = new HashMap[Long, Accumulable[_, _]]
+
+  private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized {
+    accumulators(a.id) = a
+  }
+
+  private[spark] override def collectInternalAccumulators(): Map[Long, Any] = synchronized {
+    accumulators.filter(_._2.isInternal).mapValues(_.localValue).toMap
+  }
+
+  private[spark] override def collectAccumulators(): Map[Long, Any] = synchronized {
+    accumulators.mapValues(_.localValue).toMap
+  }
+}
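
One detail worth noting in `collectInternalAccumulators`/`collectAccumulators`: `mapValues` produces a lazy view in Scala, so the trailing `.toMap` is what turns the result into a stable snapshot rather than a live window onto the task's accumulators. A small demonstration of the difference (plain Scala, nothing Spark-specific):

```scala
import scala.collection.mutable

object MapValuesDemo extends App {
  val live = mutable.HashMap(1L -> 10L)

  val view = live.mapValues(_ * 1)           // lazy: re-evaluated on every access
  val snapshot = live.mapValues(_ * 1).toMap // materialized copy

  live(1L) = 99L                             // a later task-side update
  println(view(1L))     // 99 -- the view tracks the mutation
  println(snapshot(1L)) // 10 -- the snapshot does not
}
```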

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 4 deletions
@@ -209,7 +209,7 @@ private[spark] class Executor(
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
-        val value = try {
+        val (value, accumUpdates) = try {
           task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
         } finally {
           // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread;
@@ -247,7 +247,6 @@ private[spark] class Executor(
           m.setResultSerializationTime(afterSerialization - beforeSerialization)
         }
 
-        val accumUpdates = Accumulators.values
         val directResult = new DirectTaskResult(valueBytes, accumUpdates, task.metrics.orNull)
         val serializedDirectResult = ser.serialize(directResult)
         val resultSize = serializedDirectResult.limit
@@ -314,8 +313,6 @@ private[spark] class Executor(
         env.shuffleMemoryManager.releaseMemoryForThisThread()
         // Release memory used by this thread for unrolling blocks
         env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
-        // Release memory used by this thread for accumulators
-        Accumulators.clear()
         runningTasks.remove(taskId)
       }
     }
@@ -424,6 +421,7 @@ private[spark] class Executor(
             metrics.updateShuffleReadMetrics()
             metrics.updateInputMetrics()
            metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+            metrics.updateAccumulators()
 
             if (isLocal) {
              // JobProgressListener will hold an reference of it during
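
The executor now has two reporting paths: `task.run` hands back the result together with all accumulator updates exactly once, while the heartbeat path refreshes and ships only the internal updates while the task is still running. A rough, self-contained sketch of those two paths (the names, the 100 ms interval, and the hard-coded accumulator id are illustrative, not Spark's):

```scala
import java.util.concurrent.{Executors, TimeUnit}
import scala.collection.mutable

object HeartbeatPathsSketch extends App {
  val internalAccums = mutable.Map[Long, Long](1L -> 0L) // id -> running value

  // Stand-in for context.collectInternalAccumulators(): a snapshot of live state.
  def collectInternalAccumulators(): Map[Long, Any] =
    internalAccums.synchronized { internalAccums.toMap }

  // Stand-in for task.run(): returns (result, all accumulator updates) once.
  def runTask(): (String, Map[Long, Any]) = {
    for (_ <- 1 to 5) {
      internalAccums.synchronized { internalAccums(1L) += 100L } // task-side update
      Thread.sleep(50)
    }
    ("task result", collectInternalAccumulators())
  }

  // Heartbeat path: periodically report the internal updates to the "driver".
  val heartbeater = Executors.newSingleThreadScheduledExecutor()
  heartbeater.scheduleAtFixedRate(new Runnable {
    def run(): Unit = println(s"heartbeat: ${collectInternalAccumulators()}")
  }, 0, 100, TimeUnit.MILLISECONDS)

  val (value, accumUpdates) = runTask()
  println(s"finished: $value with $accumUpdates") // sent once with the task result
  heartbeater.shutdown()
}
```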

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 16 additions & 0 deletions
@@ -223,6 +223,22 @@ class TaskMetrics extends Serializable {
     // overhead.
     _hostname = TaskMetrics.getCachedHostName(_hostname)
   }
+
+  private var _accumulatorUpdates: Map[Long, Any] = Map.empty
+  @transient private var _accumulatorsUpdater: () => Map[Long, Any] = null
+
+  private[spark] def updateAccumulators(): Unit = synchronized {
+    _accumulatorUpdates = _accumulatorsUpdater()
+  }
+
+  /**
+   * Return the latest updates of accumulators in this task.
+   */
+  def accumulatorUpdates(): Map[Long, Any] = _accumulatorUpdates
+
+  private[spark] def setAccumulatorsUpdater(accumulatorsUpdater: () => Map[Long, Any]): Unit = {
+    _accumulatorsUpdater = accumulatorsUpdater
+  }
 }
 
 private[spark] object TaskMetrics {
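
TaskMetrics stays decoupled from accumulators by storing an updater function: the task wires in a thunk (`context.collectInternalAccumulators` in this patch) and the heartbeat path simply invokes it through `updateAccumulators()`. A minimal sketch of that pattern with an illustrative class name (`MetricsSketch`):

```scala
class MetricsSketch {
  private var accumulatorUpdates: Map[Long, Any] = Map.empty
  // The thunk closes over live task state, so it would not be serialized
  // (mirroring the @transient field in the diff above).
  @transient private var updater: () => Map[Long, Any] = () => Map.empty

  def setAccumulatorsUpdater(f: () => Map[Long, Any]): Unit = { updater = f }
  def updateAccumulators(): Unit = synchronized { accumulatorUpdates = updater() }
  def latest: Map[Long, Any] = accumulatorUpdates
}

object MetricsSketchDemo extends App {
  var counter = 0L
  val metrics = new MetricsSketch
  metrics.setAccumulatorsUpdater(() => Map(1L -> counter)) // wired in by the task
  counter = 512L
  metrics.updateAccumulators()                             // called from the heartbeat path
  println(metrics.latest)                                  // Map(1 -> 512)
}
```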

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,8 @@ import java.util.Properties
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.atomic.AtomicInteger
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
+import scala.collection.Map
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Stack}
 import scala.concurrent.duration._
 import scala.language.existentials
 import scala.language.postfixOps

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import java.util.Properties
 
-import scala.collection.mutable.Map
+import scala.collection.Map
 import scala.language.existentials
 
 import org.apache.spark._

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 10 additions & 3 deletions
@@ -45,14 +45,20 @@ import org.apache.spark.util.Utils
  */
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
 
+  /**
+   * The key of the Map is the accumulator id and the value of the Map is the latest accumulator
+   * local value.
+   */
+  type AccumulatorUpdates = Map[Long, Any]
+
   /**
    * Called by [[Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
    * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
-   * @return the result of the task
+   * @return the result of the task along with updates of Accumulators.
    */
-  final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+  final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = {
     context = new TaskContextImpl(
       stageId = stageId,
       partitionId = partitionId,
@@ -62,12 +68,13 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
       runningLocally = false)
     TaskContext.setTaskContext(context)
     context.taskMetrics.setHostname(Utils.localHostName())
+    context.taskMetrics.setAccumulatorsUpdater(context.collectInternalAccumulators)
     taskThread = Thread.currentThread()
     if (_killed) {
       kill(interruptThread = false)
     }
     try {
-      runTask(context)
+      (runTask(context), context.collectAccumulators())
     } finally {
       context.markTaskCompleted()
       TaskContext.unset()
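
`Task.run` keeps its template-method shape: the final `run` owns the context lifecycle and now pairs the result of the abstract `runTask` with an accumulator snapshot taken immediately after it returns. A stripped-down, self-contained sketch of that shape (illustrative names, no Spark dependencies):

```scala
abstract class TaskSketch[T] {
  type AccumulatorUpdates = Map[Long, Any]

  // Owns the context lifecycle and pairs the result with the accumulator snapshot.
  final def run(): (T, AccumulatorUpdates) = {
    val context = new ContextSketch
    try {
      (runTask(context), context.collectAccumulators())
    } finally {
      context.markCompleted()
    }
  }

  def runTask(context: ContextSketch): T
}

class ContextSketch {
  private var accums = Map.empty[Long, Any]
  def set(id: Long, v: Any): Unit = { accums += (id -> v) }
  def collectAccumulators(): Map[Long, Any] = accums
  def markCompleted(): Unit = ()
}

object TaskSketchDemo extends App {
  val task = new TaskSketch[Int] {
    def runTask(context: ContextSketch): Int = { context.set(1L, 10L); 42 }
  }
  println(task.run()) // (42,Map(1 -> 10))
}
```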

core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala

Lines changed: 5 additions & 3 deletions
@@ -20,7 +20,8 @@ package org.apache.spark.scheduler
 import java.io._
 import java.nio.ByteBuffer
 
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
 
 import org.apache.spark.SparkEnv
 import org.apache.spark.executor.TaskMetrics
@@ -69,10 +70,11 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
     if (numUpdates == 0) {
       accumUpdates = null
     } else {
-      accumUpdates = Map()
+      val _accumUpdates = mutable.Map[Long, Any]()
       for (i <- 0 until numUpdates) {
-        accumUpdates(in.readLong()) = in.readObject()
+        _accumUpdates(in.readLong()) = in.readObject()
       }
+      accumUpdates = _accumUpdates
     }
     metrics = in.readObject().asInstanceOf[TaskMetrics]
     valueObjectDeserialized = false
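
DirectTaskResult hand-rolls its accumulator-update (de)serialization as a count followed by (id, value) pairs; with `accumUpdates` now typed as `scala.collection.Map`, the change builds the entries into a local `mutable.Map` and assigns it to the field only once it is complete. A self-contained sketch of that shape using `Externalizable` (the class `ResultSketch` is illustrative, not Spark's):

```scala
import java.io._
import scala.collection.Map
import scala.collection.mutable

class ResultSketch(var accumUpdates: Map[Long, Any]) extends Externalizable {
  def this() = this(null) // no-arg constructor required by Externalizable

  override def writeExternal(out: ObjectOutput): Unit = {
    if (accumUpdates == null) {
      out.writeInt(0)
    } else {
      out.writeInt(accumUpdates.size)
      accumUpdates.foreach { case (id, v) => out.writeLong(id); out.writeObject(v) }
    }
  }

  override def readExternal(in: ObjectInput): Unit = {
    val numUpdates = in.readInt()
    if (numUpdates == 0) {
      accumUpdates = null
    } else {
      // Build into a local mutable map, then assign to the read-only-typed field.
      val _accumUpdates = mutable.Map[Long, Any]()
      for (_ <- 0 until numUpdates) {
        _accumUpdates(in.readLong()) = in.readObject()
      }
      accumUpdates = _accumUpdates
    }
  }
}
```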

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 4 additions & 3 deletions
@@ -19,12 +19,13 @@ package org.apache.spark.scheduler
 
 import java.util.Random
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.util.ManualClock
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
@@ -37,7 +38,7 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: mutable.Map[Long, Any],
+      accumUpdates: Map[Long, Any],
       taskInfo: TaskInfo,
       taskMetrics: TaskMetrics) {
     taskScheduler.endedTasks(taskInfo.index) = reason
