
Commit b90bf52

cloud-fan authored and hvanhovell committed
[SPARK-12837][CORE] Do not send the name of internal accumulator to executor side
## What changes were proposed in this pull request?

When sending accumulator updates back to the driver, the network overhead is pretty big because there are a lot of accumulators, e.g. `TaskMetrics` will send about 20 accumulators every time, and there may be a lot of `SQLMetric`s if the query plan is complicated. Therefore, it's critical to reduce the size of serialized accumulators.

A simple way is to not send the names of internal accumulators to the executor side, as they are unnecessary there. When the executor sends accumulator updates back to the driver, we can easily look up the accumulator name in `AccumulatorContext`. Note that we still need to send the names of normal accumulators, as user code running at the executor side may rely on accumulator names.

In the future, we should reimplement `TaskMetrics` to not rely on accumulators and use custom serialization.

Tried on the example in https://issues.apache.org/jira/browse/SPARK-12837: the size of the serialized accumulators has been cut down by about 40%.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes #17596 from cloud-fan/oom.
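As a rough illustration of the approach described above, here is a small self-contained Scala sketch (not Spark's actual implementation; `NameRegistry`, `Update`, and `describe` are made-up names standing in for `AccumulatorContext` and the real update path): the executor ships only an (id, value) pair, and the driver recovers the display name by looking the id up in a driver-side registry.

```scala
import scala.collection.mutable

// Toy stand-in for AccumulatorContext: the driver remembers id -> name.
object NameRegistry {
  private val names = mutable.Map.empty[Long, String]
  def register(id: Long, name: String): Unit = names(id) = name
  def nameOf(id: Long): Option[String] = names.get(id)
}

// What an executor would send back after this change: no name, just id and value.
final case class Update(id: Long, value: Long)

object NameLookupDemo {
  // Driver-side: resolve the human-readable name from the id carried by the update.
  def describe(u: Update): String = {
    val name = NameRegistry.nameOf(u.id).getOrElse("accumulator " + u.id)
    s"$name = ${u.value}"
  }

  def main(args: Array[String]): Unit = {
    NameRegistry.register(1L, "internal.metrics.resultSize")
    println(describe(Update(1L, 2048L))) // prints: internal.metrics.resultSize = 2048
  }
}
```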
1 parent 823baca

File tree: 8 files changed (+76, -54 lines)


core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 13 additions & 16 deletions
@@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable {
 
   private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    accumulators.find { acc =>
-      acc.name.isDefined && acc.name.get == name
-    }
+  private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = {
+    // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
+    // value will be updated at driver side.
+    internalAccums.filter(a => !a.isZero || a == _resultSize)
   }
 }
 
@@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging {
    */
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
-    val (internalAccums, externalAccums) =
-      accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
-
-    internalAccums.foreach { acc =>
-      val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
-      tmAcc.metadata = acc.metadata
-      tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+    for (acc <- accums) {
+      val name = acc.name
+      if (name.isDefined && tm.nameToAccums.contains(name.get)) {
+        val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
+        tmAcc.metadata = acc.metadata
+        tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+      } else {
+        tm.externalAccums += acc
+      }
     }
-
-    tm.externalAccums ++= externalAccums
     tm
   }
 }

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 8 deletions
@@ -182,14 +182,11 @@ private[spark] abstract class Task[T](
    */
   def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = {
     if (context != null) {
-      context.taskMetrics.internalAccums.filter { a =>
-        // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
-        // value will be updated at driver side.
-        // Note: internal accumulators representing task metrics always count failed values
-        !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE)
-        // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter
-        // them out.
-      } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
+      // Note: internal accumulators representing task metrics always count failed values
+      context.taskMetrics.nonZeroInternalAccums() ++
+        // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not
+        // filter them out.
+        context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
     } else {
       Seq.empty
     }

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 15 additions & 13 deletions
@@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    * Returns the name of this accumulator, can only be called after registration.
    */
   final def name: Option[String] = {
-    assertMetadataNotNull()
-    metadata.name
+    if (atDriverSide) {
+      AccumulatorContext.get(id).flatMap(_.metadata.name)
+    } else {
+      assertMetadataNotNull()
+      metadata.name
+    }
   }
 
   /**
@@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       val copyAcc = copyAndReset()
       assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
-      copyAcc.metadata = metadata
+      val isInternalAcc =
+        (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) ||
+          getClass.getSimpleName == "SQLMetric"
+      if (isInternalAcc) {
+        // Do not serialize the name of internal accumulator and send it to executor.
+        copyAcc.metadata = metadata.copy(name = None)
+      } else {
+        copyAcc.metadata = metadata
+      }
       copyAcc
     } else {
       this
@@ -263,16 +275,6 @@ private[spark] object AccumulatorContext {
     originals.clear()
   }
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    originals.values().asScala.find { ref =>
-      val acc = ref.get
-      acc != null && acc.name.isDefined && acc.name.get == name
-    }.map(_.get)
-  }
-
   // Identifier for distinguishing SQL metrics from other accumulators
   private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
 }

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     sc = new SparkContext("local", "test")
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
-    val taskMetrics = TaskMetrics.empty
+    val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
       context = new TaskContextImpl(0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),

core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"
 
     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = TaskMetrics.empty
+      val taskMetrics = TaskMetrics.registered
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = TaskMetrics.empty
+    val t = TaskMetrics.registered
     // Set CPU times same as wall times for testing purpose
     t.setExecutorDeserializeTime(a)
     t.setExecutorDeserializeCpuTime(a)

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java

Lines changed: 6 additions & 6 deletions
@@ -153,14 +153,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
     }
 
     // For test purpose.
-    // If the predefined accumulator exists, the row group number to read will be updated
-    // to the accumulator. So we can check if the row groups are filtered or not in test case.
+    // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
+    // will be updated to the accumulator. So we can check if the row groups are filtered or not
+    // in test case.
     TaskContext taskContext = TaskContext$.MODULE$.get();
     if (taskContext != null) {
-      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
-          .lookForAccumulatorByName("numRowGroups");
-      if (accu.isDefined()) {
-        ((LongAccumulator)accu.get()).add((long)blocks.size());
+      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics().externalAccums().lastOption();
+      if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
+        ((AccumulatorV2<Integer, Integer>)accu.get()).add(blocks.size());
       }
     }
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 34 additions & 8 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       val path = s"${dir.getCanonicalPath}/table"
       (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-      Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-          val accu = new LongAccumulator
-          accu.register(sparkContext, Some("numRowGroups"))
+      Seq(true, false).foreach { enablePushDown =>
+        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) {
+          val accu = new NumRowGroupsAcc
+          sparkContext.register(accu)
 
           val df = spark.read.parquet(path).filter("a < 100")
           df.foreachPartition(_.foreach(v => accu.add(0)))
           df.collect
 
-          val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-          assert(numRowGroups.isDefined)
-          assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+          if (enablePushDown) {
+            assert(accu.value == 0)
+          } else {
+            assert(accu.value > 0)
+          }
           AccumulatorContext.remove(accu.id)
         }
       }
@@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 }
+
+class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
+  private var _sum = 0
+
+  override def isZero: Boolean = _sum == 0
+
+  override def copy(): AccumulatorV2[Integer, Integer] = {
+    val acc = new NumRowGroupsAcc()
+    acc._sum = _sum
+    acc
+  }
+
+  override def reset(): Unit = _sum = 0
+
+  override def add(v: Integer): Unit = _sum += v
+
+  override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match {
+    case a: NumRowGroupsAcc => _sum += a._sum
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def value: Integer = _sum
+}
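For completeness, here is a hedged usage sketch of a custom `AccumulatorV2` such as the `NumRowGroupsAcc` defined above; the local `SparkSession` setup and the `NumRowGroupsAccDemo` object name are illustrative and not part of the commit.

```scala
import org.apache.spark.sql.SparkSession

object NumRowGroupsAccDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("acc-demo").getOrCreate()
    val sc = spark.sparkContext

    val accu = new NumRowGroupsAcc
    sc.register(accu) // register on the driver; a name is optional for user accumulators

    // Each record adds 1 in the executors; per-task copies are merged back on the driver.
    sc.parallelize(1 to 4, 2).foreach(_ => accu.add(1))
    assert(accu.value == 4)

    spark.stop()
  }
}
```

Because the accumulator here is registered without a name, nothing changes for it under this patch; only named internal accumulators (and `SQLMetric`s) have their names dropped before being shipped to executors.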
