diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 04faf7f87cf2..8ffccdf664b2 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -176,7 +176,7 @@ private[spark] class BarrierCoordinator(
       logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
         s"$taskId, current progress: ${requesters.size}/$numTasks.")
       if (requesters.size == numTasks) {
-        requesters.foreach(_.reply(messages))
+        requesters.foreach(_.reply(messages.clone()))
         // Finished current barrier() call successfully, clean up ContextBarrierState and
         // increase the barrier epoch.
         logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index 4f97003e2ed5..26cd5374fa09 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -367,4 +367,27 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
     // double check we kill task success
     assert(System.currentTimeMillis() - startTime < 5000)
   }
+
+  test("SPARK-40932, messages of allGather should not been overridden " +
+    "by the following barrier APIs") {
+
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[2]"))
+    sc.setLogLevel("INFO")
+    val rdd = sc.makeRDD(1 to 10, 2)
+    val rdd2 = rdd.barrier().mapPartitions { it =>
+      val context = BarrierTaskContext.get()
+      // Sleep for a random time before global sync.
+      Thread.sleep(Random.nextInt(1000))
+      // Pass partitionId message in
+      val message: String = context.partitionId().toString
+      val messages: Array[String] = context.allGather(message)
+      context.barrier()
+      Iterator.single(messages.toList)
+    }
+    val messages = rdd2.collect()
+    // All the task partitionIds are shared across all tasks
+    assert(messages.length === 2)
+    assert(messages.forall(_ == List("0", "1")))
+  }
 }
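
Why the one-line fix works: the coordinator's `messages` buffer is a mutable `Array[String]` that is reused across barrier epochs, and the RPC layer may serialize a reply only after the handler has returned. Replying with the array itself therefore lets a subsequent `barrier()` call overwrite the `allGather()` payload before it reaches the tasks, which is exactly what the new test reproduces; `clone()` hands each requester an independent snapshot. Below is a minimal sketch of that aliasing hazard in plain Scala, under the assumption that replies can be serialized lazily. `Reply` and `SharedArrayHazard` are hypothetical names for illustration, not Spark's RPC API.

```scala
// Minimal sketch (plain Scala, not Spark code) of the aliasing hazard that
// messages.clone() guards against. `Reply` is a hypothetical stand-in for an
// RPC response whose payload may be serialized only after the coordinator has
// moved on to the next barrier epoch and mutated its shared buffer.
object SharedArrayHazard {
  final case class Reply(payload: Array[String])

  def main(args: Array[String]): Unit = {
    val messages = Array("0", "1")        // allGather() results for epoch N

    val unsafe = Reply(messages)          // aliases the coordinator's buffer
    val safe   = Reply(messages.clone())  // owns an independent snapshot

    // The next barrier() call reuses the same buffer for epoch N + 1
    // before the replies above have been serialized.
    messages(0) = ""
    messages(1) = ""

    println(unsafe.payload.toList)        // List("", "") -- gathered data lost
    println(safe.payload.toList)          // List("0", "1") -- snapshot intact
  }
}
```

A shallow `clone()` is enough here: the `String` elements are immutable, so only the array container itself needs to stop being shared.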