diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 5663055129d19..04faf7f87cf2b 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -21,7 +21,7 @@ import java.util.{Timer, TimerTask}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.Consumer
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
@@ -106,9 +106,11 @@ private[spark] class BarrierCoordinator(
     // The messages will be replied to all tasks once sync finished.
     private val messages = Array.ofDim[String](numTasks)
 
-    // The request method which is called inside this barrier sync. All tasks should make sure
-    // that they're calling the same method within the same barrier sync phase.
-    private var requestMethod: RequestMethod.Value = _
+    // Request methods collected from tasks inside this barrier sync. All tasks should make sure
+    // that they're calling the same method within the same barrier sync phase. In other words,
+    // the size of requestMethods should always be 1 for a legitimate barrier sync. Otherwise,
+    // the barrier sync would fail if the size of requestMethods becomes greater than 1.
+    private val requestMethods = new HashSet[RequestMethod.Value]
 
     // A timer task that ensures we may timeout for a barrier() call.
     private var timerTask: TimerTask = null
@@ -141,17 +143,14 @@ private[spark] class BarrierCoordinator(
       val taskId = request.taskAttemptId
       val epoch = request.barrierEpoch
       val curReqMethod = request.requestMethod
-
-      if (requesters.isEmpty) {
-        requestMethod = curReqMethod
-      } else if (requestMethod != curReqMethod) {
-        requesters.foreach(
-          _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
-            s"`$curReqMethod` during barrier epoch $barrierEpoch, which does not match " +
-            s"the current synchronized requestMethod `$requestMethod`"
-          ))
-        )
-        cleanupBarrierStage(barrierId)
+      requestMethods.add(curReqMethod)
+      if (requestMethods.size > 1) {
+        val error = new SparkException(s"Different barrier sync types found for the " +
+          s"sync $barrierId: ${requestMethods.mkString(", ")}. Please use the " +
+          s"same barrier sync type within a single sync.")
+        (requesters :+ requester).foreach(_.sendFailure(error))
+        clear()
+        return
       }
 
       // Require the number of tasks is correctly set from the BarrierTaskContext.
@@ -184,6 +183,7 @@ private[spark] class BarrierCoordinator(
           s"tasks, finished successfully.")
         barrierEpoch += 1
         requesters.clear()
+        requestMethods.clear()
         cancelTimerTask()
       }
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index c4e5e7c700652..f1578e7d91a52 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -95,10 +95,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with
     val error = intercept[SparkException] {
       rdd2.collect()
     }.getMessage
-    assert(
-      error.contains("does not match the current synchronized requestMethod") ||
-        error.contains("not properly killed")
-    )
+    assert(error.contains("Different barrier sync types found"))
   }
 
   test("successively sync with allGather and barrier") {