diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index be5036e82e4b2..5663055129d19 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,17 +17,12 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -107,11 +102,13 @@ private[spark] class BarrierCoordinator( // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) - // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call - private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] + // Messages from each barrier task that have made a blocking runBarrier() call. + // The messages will be replied to all tasks once sync finished. + private val messages = Array.ofDim[String](numTasks) - // The blocking requestMethod called by tasks to sync up for this stage attempt - private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER + // The request method which is called inside this barrier sync. All tasks should make sure + // that they're calling the same method within the same barrier sync phase. + private var requestMethod: RequestMethod.Value = _ // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -140,28 +137,18 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest( - requester: RpcCallContext, - request: RequestToSync - ): Unit = synchronized { + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch - val requestMethod = request.requestMethod - val partitionId = request.partitionId - val allGatherMessage = request match { - case ag: AllGatherRequestToSync => ag.allGatherMessage - case _ => "" - } - - if (requesters.size == 0) { - requestMethodToSync = requestMethod - } + val curReqMethod = request.requestMethod - if (requestMethodToSync != requestMethod) { + if (requesters.isEmpty) { + requestMethod = curReqMethod + } else if (requestMethod != curReqMethod) { requesters.foreach( _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + - s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + - s"the current synchronized requestMethod `$requestMethodToSync`" + s"`$curReqMethod` during barrier epoch $barrierEpoch, which does not match " + + s"the current synchronized requestMethod `$requestMethod`" )) ) cleanupBarrierStage(barrierId) @@ -186,10 +173,11 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester - allGatherMessages(partitionId) = allGatherMessage + messages(request.partitionId) = request.message logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") - if (maybeFinishAllRequesters(requesters, numTasks)) { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(messages)) // Finished current barrier() call successfully, clean up ContextBarrierState and // increase the barrier epoch. logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + @@ -201,25 +189,6 @@ private[spark] class BarrierCoordinator( } } - // Finish all the blocking barrier sync requests from a stage attempt successfully if we - // have received all the sync requests. - private def maybeFinishAllRequesters( - requesters: ArrayBuffer[RpcCallContext], - numTasks: Int): Boolean = { - if (requesters.size == numTasks) { - requestMethodToSync match { - case RequestMethod.BARRIER => - requesters.foreach(_.reply("")) - case RequestMethod.ALL_GATHER => - val json: String = compact(render(allGatherMessages)) - requesters.foreach(_.reply(json)) - } - true - } else { - false - } - } - // Cleanup the internal state of a barrier stage attempt. def clear(): Unit = synchronized { // The global sync fails so the stage is expected to retry another attempt, all sync @@ -239,11 +208,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request: RequestToSync => + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _, _, _, _) => // Get or init the ContextBarrierState correspond to the stage attempt. - val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) + val barrierId = ContextBarrierId(stageId, stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -256,61 +225,28 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable -private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { - def numTasks: Int - def stageId: Int - def stageAttemptId: Int - def taskAttemptId: Long - def barrierEpoch: Int - def partitionId: Int - def requestMethod: RequestMethod.Value -} - -/** - * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is - * identified by stageId + stageAttemptId + barrierEpoch. - * - * @param numTasks The number of global sync requests the BarrierCoordinator shall receive - * @param stageId ID of current stage - * @param stageAttemptId ID of current stage attempt - * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls - * @param partitionId ID of the current partition the task is assigned to - * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator - */ -private[spark] case class BarrierRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value -) extends RequestToSync - /** - * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is + * A global sync request message from BarrierTaskContext. Each request is * identified by stageId + stageAttemptId + barrierEpoch. * * @param numTasks The number of global sync requests the BarrierCoordinator shall receive * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param barrierEpoch ID of a runBarrier() call, a task may consist multiple runBarrier() calls * @param partitionId ID of the current partition the task is assigned to + * @param message Message sent from the BarrierTaskContext * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator - * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER */ -private[spark] case class AllGatherRequestToSync( +private[spark] case class RequestToSync( numTasks: Int, stageId: Int, stageAttemptId: Int, taskAttemptId: Long, barrierEpoch: Int, partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String -) extends RequestToSync + message: String, + requestMethod: RequestMethod.Value) extends BarrierCoordinatorMessage private[spark] object RequestMethod extends Enumeration { val BARRIER, ALL_GATHER = Value diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 0c2ceb1a02c7b..06f8024847b90 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,20 +17,13 @@ package org.apache.spark -import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ import scala.language.postfixOps -import org.json4s.DefaultFormats -import org.json4s.JsonAST._ -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.parse - import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging @@ -67,31 +60,7 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - private def getRequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptNumber: Int, - taskAttemptId: Long, - barrierEpoch: Int, - partitionId: Int, - requestMethod: RequestMethod.Value, - allGatherMessage: String - ): RequestToSync = { - requestMethod match { - case RequestMethod.BARRIER => - BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod) - case RequestMethod.ALL_GATHER => - AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch, partitionId, requestMethod, allGatherMessage) - } - } - - private def runBarrier( - requestMethod: RequestMethod.Value, - allGatherMessage: String = "" - ): String = { - + private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = { logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -108,16 +77,16 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) - var json: String = "" - try { - val abortableRpcFuture = barrierCoordinator.askAbortable[String]( - message = getRequestToSync(numTasks, stageId, stageAttemptNumber, - taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), + val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, message, requestMethod), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) + // messages which consist of all barrier tasks' messages + var messages: Array[String] = null // Wait the RPC future to be completed, but every 1 second it will jump out waiting // and check whether current spark task is killed. If killed, then throw // a `TaskKilledException`, otherwise continue wait RPC until it completes. @@ -125,7 +94,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -144,6 +113,7 @@ class BarrierTaskContext private[spark] ( "global sync successfully, waited for " + s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " + s"current barrier epoch is $barrierEpoch.") + messages } catch { case e: SparkException => logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + @@ -155,7 +125,6 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } - json } /** @@ -200,10 +169,7 @@ class BarrierTaskContext private[spark] ( */ @Experimental @Since("2.4.0") - def barrier(): Unit = { - runBarrier(RequestMethod.BARRIER) - () - } + def barrier(): Unit = runBarrier("", RequestMethod.BARRIER) /** * :: Experimental :: @@ -217,12 +183,7 @@ class BarrierTaskContext private[spark] ( */ @Experimental @Since("3.0.0") - def allGather(message: String): Array[String] = { - val json = runBarrier(RequestMethod.ALL_GATHER, message) - val jsonArray = parse(json) - implicit val formats = DefaultFormats - jsonArray.extract[Array[String]] - } + def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER) /** * :: Experimental :: diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 06c9446c7534e..39c4abea933a7 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -414,22 +414,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( ) val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - var result: String = "" - requestMethod match { + val messages = requestMethod match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => context.asInstanceOf[BarrierTaskContext].barrier() - result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS + Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val messages: Array[String] = context.asInstanceOf[BarrierTaskContext].allGather( - message - ) - result = compact(render(JArray( - messages.map( - (message) => JString(message) - ).toList - ))) + context.asInstanceOf[BarrierTaskContext].allGather(message) } - writeUTF(result, out) + out.writeInt(messages.length) + messages.foreach(writeUTF(_, out)) } catch { case e: SparkException => writeUTF(e.getMessage, out) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 0dd8be72dc904..bcf1fe2c2aa11 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.scheduler import java.io.File -import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index e4fc64b732ba7..8f419a5e8446a 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -19,7 +19,7 @@ import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, write_with_length, UTF8Deserializer +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer class TaskContext(object): @@ -133,7 +133,10 @@ def _load_from_socket(port, auth_secret, function, all_gather_message=None): sockfile.flush() # Collect result. - res = UTF8Deserializer().loads(sockfile) + len = read_int(sockfile) + res = [] + for i in range(len): + res.append(UTF8Deserializer().loads(sockfile)) # Release resources. sockfile.close() @@ -232,13 +235,7 @@ def allGather(self, message=""): raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - gathered_items = _load_from_socket( - self._port, - self._secret, - ALL_GATHER_FUNCTION, - message, - ) - return [e for e in json.loads(gathered_items)] + return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message) def getTaskInfos(self): """