Commit 0616716

Do not send the accumulator name to executor side
1 parent: cd91f96

8 files changed (+79, -55 lines)

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 14 additions & 16 deletions
@@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable {
 
   private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    accumulators.find { acc =>
-      acc.name.isDefined && acc.name.get == name
-    }
+  private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = {
+    // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
+    // value will be updated at driver side.
+    internalAccums.filter(a => !a.isZero || a == _resultSize)
   }
 }
 
@@ -308,16 +305,17 @@ private[spark] object TaskMetrics extends Logging {
    */
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
-    val (internalAccums, externalAccums) =
-      accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
-
-    internalAccums.foreach { acc =>
-      val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
-      tmAcc.metadata = acc.metadata
-      tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+    for (acc <- accums) {
+      val name = AccumulatorContext.get(acc.id).flatMap(_.name)
+      if (name.isDefined && tm.nameToAccums.contains(name.get)) {
+        val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
+        tmAcc.metadata = acc.metadata
+        tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+      } else {
+        acc.metadata = acc.metadata.copy(name = name)
+        tm.externalAccums += acc
+      }
     }
-
-    tm.externalAccums ++= externalAccums
     tm
   }
 }
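
Both hunks above, like TaskResultGetter.scala below, stop reading the name off the deserialized update and instead look it up in the driver-side registry by accumulator id. A minimal sketch of that lookup pattern, assuming it lives inside Spark's own source tree (AccumulatorContext is private[spark]); the helper object and method names are hypothetical:

package org.apache.spark.util

// Hypothetical helper, not part of this commit: given an accumulator update that arrived
// from an executor without a name, find the registered original on the driver by its id
// and read the name from there.
private[spark] object AccumulatorNameLookup {
  def driverSideName(update: AccumulatorV2[_, _]): Option[String] =
    AccumulatorContext.get(update.id).flatMap(_.name)
}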

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 8 deletions
@@ -182,14 +182,11 @@ private[spark] abstract class Task[T](
    */
   def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = {
     if (context != null) {
-      context.taskMetrics.internalAccums.filter { a =>
-        // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
-        // value will be updated at driver side.
-        // Note: internal accumulators representing task metrics always count failed values
-        !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE)
-        // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter
-        // them out.
-      } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
+      // Note: internal accumulators representing task metrics always count failed values
+      context.taskMetrics.nonZeroInternalAccums() ++
+        // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not
+        // filter them out.
+        context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
     } else {
       Seq.empty
     }
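
The rule for what a task reports back is unchanged by this commit, only split between nonZeroInternalAccums() and the external-accumulator filter. A hedged restatement in plain Scala (the Acc case class and its field names are illustrative, not Spark's API):

// Illustrative model of what collectAccumulatorUpdates sends back to the driver.
final case class Acc(
    internal: Boolean,          // true for task-metric accumulators
    isZero: Boolean,            // still at the accumulator's zero value
    isResultSize: Boolean,      // the RESULT_SIZE metric
    countFailedValues: Boolean) // external accumulator opted in to updates from failed tasks

def updatesToSend(accums: Seq[Acc], taskFailed: Boolean): Seq[Acc] = {
  // Internal metrics: drop zero values, but always keep RESULT_SIZE because the driver
  // fills in its value after measuring the serialized task result.
  accums.filter(a => a.internal && (!a.isZero || a.isResultSize)) ++
    // External accumulators: keep them even when zero (e.g. SQLMetrics), unless the task
    // failed and the accumulator does not count failed values.
    accums.filter(a => !a.internal && (!taskFailed || a.countFailedValues))
}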

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,7 @@ import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}
+import org.apache.spark.util.{AccumulatorContext, LongAccumulator, ThreadUtils, Utils}
 
 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
@@ -100,7 +100,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
           // We need to do this here on the driver because if we did this on the executors then
           // we would have to serialize the result again after updating the size.
           result.accumUpdates = result.accumUpdates.map { a =>
-            if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
+            val accName = AccumulatorContext.get(a.id).flatMap(_.name)
+            if (accName == Some(InternalAccumulator.RESULT_SIZE)) {
               val acc = a.asInstanceOf[LongAccumulator]
               assert(acc.sum == 0L, "task result size should not have been set on the executors")
               acc.setValue(size.toLong)
core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 9 additions & 11 deletions
@@ -161,7 +161,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
       }
       val copyAcc = copyAndReset()
       assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
-      copyAcc.metadata = metadata
+      val isInternalAcc =
+        (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) ||
+          getClass.getSimpleName == "SQLMetric"
+      if (isInternalAcc) {
+        // Do not serialize the name of internal accumulator and send it to executor.
+        copyAcc.metadata = metadata.copy(name = None)
+      } else {
+        copyAcc.metadata = metadata
+      }
       copyAcc
     } else {
       this
@@ -263,16 +271,6 @@ private[spark] object AccumulatorContext {
     originals.clear()
   }
 
-  /**
-   * Looks for a registered accumulator by accumulator name.
-   */
-  private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
-    originals.values().asScala.find { ref =>
-      val acc = ref.get
-      acc != null && acc.name.isDefined && acc.name.get == name
-    }.map(_.get)
-  }
-
   // Identifier for distinguishing SQL metrics from other accumulators
   private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
 }
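
The new branch in writeReplace only drops names the driver can recover later from its registry: task-metric accumulators (whose names carry the internal-metrics prefix) and SQL metrics. A hypothetical restatement of the predicate, with the prefix value inlined as an assumption (it mirrors InternalAccumulator.METRICS_PREFIX):

// Illustrative only: which accumulators ship to executors without a name under the new rule.
val metricsPrefix = "internal.metrics."  // assumed value of InternalAccumulator.METRICS_PREFIX

def shipsWithoutName(name: Option[String], simpleClassName: String): Boolean =
  name.exists(_.startsWith(metricsPrefix)) || simpleClassName == "SQLMetric"

// shipsWithoutName(Some("internal.metrics.resultSize"), "LongAccumulator")  // true
// shipsWithoutName(Some("myCounter"), "LongAccumulator")                    // false
// shipsWithoutName(Some("number of output rows"), "SQLMetric")              // true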

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     sc = new SparkContext("local", "test")
     // Create a dummy task. We won't end up running this; we just want to collect
     // accumulator updates from it.
-    val taskMetrics = TaskMetrics.empty
+    val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
       context = new TaskContextImpl(0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),

core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala

Lines changed: 7 additions & 3 deletions
@@ -36,7 +36,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
-import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}
+import org.apache.spark.util.{AccumulatorContext, MutableURLClassLoader, RpcUtils, Utils}
 
 
 /**
@@ -242,8 +242,12 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     assert(resultGetter.taskResults.size === 1)
     val resBefore = resultGetter.taskResults.head
     val resAfter = captor.getValue
-    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
-    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+    val resSizeBefore = resBefore.accumUpdates.find { acc =>
+      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
+    }.map(_.value)
+    val resSizeAfter = resAfter.accumUpdates.find { acc =>
+      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
+    }.map(_.value)
     assert(resSizeBefore.exists(_ == 0L))
     assert(resSizeAfter.exists(_.toString.toLong > 0L))
   }

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java

Lines changed: 6 additions & 6 deletions
@@ -153,14 +153,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
     }
 
     // For test purpose.
-    // If the predefined accumulator exists, the row group number to read will be updated
-    // to the accumulator. So we can check if the row groups are filtered or not in test case.
+    // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
+    // will be updated to the accumulator. So we can check if the row groups are filtered or not
+    // in test case.
     TaskContext taskContext = TaskContext$.MODULE$.get();
     if (taskContext != null) {
-      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics()
-          .lookForAccumulatorByName("numRowGroups");
-      if (accu.isDefined()) {
-        ((LongAccumulator)accu.get()).add((long)blocks.size());
+      Option<AccumulatorV2<?, ?>> accu = taskContext.taskMetrics().externalAccums().lastOption();
+      if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
+        ((AccumulatorV2<Integer, Integer>)accu.get()).add(blocks.size());
       }
     }
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 34 additions & 8 deletions
@@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       val path = s"${dir.getCanonicalPath}/table"
       (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path)
 
-      Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) =>
-        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) {
-          val accu = new LongAccumulator
-          accu.register(sparkContext, Some("numRowGroups"))
+      Seq(true, false).foreach { enablePushDown =>
+        withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) {
+          val accu = new NumRowGroupsAcc
+          sparkContext.register(accu)
 
           val df = spark.read.parquet(path).filter("a < 100")
           df.foreachPartition(_.foreach(v => accu.add(0)))
           df.collect
 
-          val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups")
-          assert(numRowGroups.isDefined)
-          assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value))
+          if (enablePushDown) {
+            assert(accu.value == 0)
+          } else {
+            assert(accu.value > 0)
+          }
           AccumulatorContext.remove(accu.id)
         }
       }
@@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
     }
   }
 }
+
+class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] {
+  private var _sum = 0
+
+  override def isZero: Boolean = _sum == 0
+
+  override def copy(): AccumulatorV2[Integer, Integer] = {
+    val acc = new NumRowGroupsAcc()
+    acc._sum = _sum
+    acc
+  }
+
+  override def reset(): Unit = _sum = 0
+
+  override def add(v: Integer): Unit = _sum += v
+
+  override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match {
+    case a: NumRowGroupsAcc => _sum += a._sum
+    case _ => throw new UnsupportedOperationException(
+      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+  }
+
+  override def value: Integer = _sum
+}
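
For context, a minimal driver-side usage sketch of the new NumRowGroupsAcc, mirroring the test above (the session setup and the Parquet path are assumptions, not taken from the test): register the accumulator without a name, run a scan, and read the merged value back on the driver. SpecificParquetRecordReaderBase adds the scanned row-group count to the last registered external accumulator, so the value stays 0 when row groups were filtered out by pushed-down predicates.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("row-group-count").getOrCreate()

val accu = new NumRowGroupsAcc
spark.sparkContext.register(accu)  // unnamed, so no name is serialized to executors

// Hypothetical path; in the test the data is written under a temp directory first.
val df = spark.read.parquet("/tmp/parquet-table").filter("a < 100")
df.foreachPartition(_.foreach(_ => accu.add(0)))  // run one task per partition to trigger the scan

println(s"row groups read: ${accu.value}")  // 0 when the filter was pushed down, > 0 otherwise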
