
Commit 4fcee4f

address comments

1 parent 0616716

6 files changed (+14, -15 lines)

core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala

Lines changed: 1 addition & 1 deletion
@@ -306,7 +306,7 @@ private[spark] object TaskMetrics extends Logging {
   def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
     val tm = new TaskMetrics
     for (acc <- accums) {
-      val name = AccumulatorContext.get(acc.id).flatMap(_.name)
+      val name = acc.name
       if (name.isDefined && tm.nameToAccums.contains(name.get)) {
         val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
         tmAcc.metadata = acc.metadata
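This simplification works because AccumulatorV2.name (changed in core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala below) now performs the AccumulatorContext lookup itself when called on the driver, so call sites can match updates by name directly. A minimal sketch of the pattern, using a hypothetical helper that is not part of Spark:

import org.apache.spark.util.AccumulatorV2

// Hypothetical helper (not in Spark): pick the update with a given name,
// relying on AccumulatorV2.name instead of an explicit registry lookup.
def pickByName(
    updates: Seq[AccumulatorV2[_, _]],
    wanted: String): Option[AccumulatorV2[_, _]] = {
  updates.find(_.name == Some(wanted))
}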

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 2 additions & 3 deletions
@@ -27,7 +27,7 @@ import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
-import org.apache.spark.util.{AccumulatorContext, LongAccumulator, ThreadUtils, Utils}
+import org.apache.spark.util.{LongAccumulator, ThreadUtils, Utils}

 /**
  * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
@@ -100,8 +100,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
       // We need to do this here on the driver because if we did this on the executors then
       // we would have to serialize the result again after updating the size.
       result.accumUpdates = result.accumUpdates.map { a =>
-        val accName = AccumulatorContext.get(a.id).flatMap(_.name)
-        if (accName == Some(InternalAccumulator.RESULT_SIZE)) {
+        if (a.name == Some(InternalAccumulator.RESULT_SIZE)) {
          val acc = a.asInstanceOf[LongAccumulator]
          assert(acc.sum == 0L, "task result size should not have been set on the executors")
          acc.setValue(size.toLong)
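The patch-up still happens on the driver, since re-serializing the result on the executor after updating the size would be wasteful; only the name check got simpler. A rough, self-contained sketch of the same rewrite, using a plain string where the real code uses InternalAccumulator.RESULT_SIZE (assumed here to be "internal.metrics.resultSize"):

import org.apache.spark.util.{AccumulatorV2, LongAccumulator}

// Illustrative only, not Spark source: overwrite the named size accumulator
// in a set of task updates, leaving all other updates untouched.
def fillResultSize(
    updates: Seq[AccumulatorV2[_, _]],
    size: Long): Seq[AccumulatorV2[_, _]] = {
  updates.map {
    case acc: LongAccumulator if acc.name == Some("internal.metrics.resultSize") =>
      acc.setValue(size)
      acc
    case other => other
  }
}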

core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala

Lines changed: 6 additions & 2 deletions
@@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
    * Returns the name of this accumulator, can only be called after registration.
    */
   final def name: Option[String] = {
-    assertMetadataNotNull()
-    metadata.name
+    if (atDriverSide) {
+      AccumulatorContext.get(id).flatMap(_.metadata.name)
+    } else {
+      assertMetadataNotNull()
+      metadata.name
+    }
   }

   /**
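With this change, name called on the driver resolves through the registered original in AccumulatorContext, so a deserialized per-task update reports the same name as the accumulator it came from. A minimal sketch against the public API, assuming a local-mode SparkContext:

import org.apache.spark.{SparkConf, SparkContext}

// Minimal driver-side demo: the accumulator is registered on creation,
// so name resolves via the driver-side registry rather than local metadata.
object NameLookupDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("name-lookup-demo"))
    val counter = sc.longAccumulator("records")
    sc.parallelize(1 to 10).foreach(_ => counter.add(1))
    assert(counter.name == Some("records"))
    println(s"${counter.name.get} = ${counter.value}")
    sc.stop()
  }
}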

core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala

Lines changed: 3 additions & 7 deletions
@@ -36,7 +36,7 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark._
 import org.apache.spark.storage.TaskResultBlockId
 import org.apache.spark.TestUtils.JavaSourceFromString
-import org.apache.spark.util.{AccumulatorContext, MutableURLClassLoader, RpcUtils, Utils}
+import org.apache.spark.util.{MutableURLClassLoader, RpcUtils, Utils}


 /**
@@ -242,12 +242,8 @@ class TaskResultGetterSuite extends SparkFunSuite with BeforeAndAfter with Local
     assert(resultGetter.taskResults.size === 1)
     val resBefore = resultGetter.taskResults.head
     val resAfter = captor.getValue
-    val resSizeBefore = resBefore.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
-    val resSizeAfter = resAfter.accumUpdates.find { acc =>
-      AccumulatorContext.get(acc.id).flatMap(_.name) == Some(RESULT_SIZE)
-    }.map(_.value)
+    val resSizeBefore = resBefore.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
+    val resSizeAfter = resAfter.accumUpdates.find(_.name == Some(RESULT_SIZE)).map(_.value)
     assert(resSizeBefore.exists(_ == 0L))
     assert(resSizeAfter.exists(_.toString.toLong > 0L))
   }

core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with
     val execId = "exe-1"

     def makeTaskMetrics(base: Int): TaskMetrics = {
-      val taskMetrics = TaskMetrics.empty
+      val taskMetrics = TaskMetrics.registered
       val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics()
       val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics
       val inputMetrics = taskMetrics.inputMetrics

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       hasHadoopInput: Boolean,
       hasOutput: Boolean,
       hasRecords: Boolean = true) = {
-    val t = TaskMetrics.empty
+    val t = TaskMetrics.registered
     // Set CPU times same as wall times for testing purpose
     t.setExecutorDeserializeTime(a)
     t.setExecutorDeserializeCpuTime(a)
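Both test files switch from TaskMetrics.empty to TaskMetrics.registered because the driver-side name lookup above only succeeds for accumulators registered with AccumulatorContext. A rough sketch of the assumed difference between the two factories (not the actual Spark source):

// Assumed shape of the factory the tests now use: like `empty`, but the
// internal accumulators are also registered, so driver-side `name` resolves.
private[spark] def registered: TaskMetrics = {
  val tm = empty
  tm.internalAccums.foreach(AccumulatorContext.register)
  tm
}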
