From 692b8ab2917ecdc6aef0c1e70c28d243fee70492 Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Wed, 29 Jan 2020 22:18:11 -0800 Subject: [PATCH 1/7] Add all gather method to BarrierTaskContext change method to allGather fix docstring fix test test2 test3 test4 test5 Change API to send and receive bytes rather than strings doc fix doc fix 2 fix test fix test 2 fix test 3 fix test 4 fix test 5 fix test 6 fix test 7 fix test final add python test fix test final 2 address review round 1 Change allGather API to accept string over bytes addressed review feedback round 2 comments rm trailing whitespace address review feedback round 2 address review round 3 address review feedback round 4 address review round 5 address review round 6 fix test retrigger build retrigger build add mima exclusion rule fix semicolon fix tests fix python unit test fix python unit test final temp --- .../org/apache/spark/BarrierCoordinator.scala | 111 +++++++++++-- .../org/apache/spark/BarrierTaskContext.scala | 153 ++++++++++++------ .../spark/api/python/PythonRunner.scala | 51 ++++-- .../scheduler/BarrierTaskContextSuite.scala | 74 +++++++++ project/MimaExcludes.scala | 5 +- python/pyspark/taskcontext.py | 60 ++++--- python/pyspark/tests/test_taskcontext.py | 23 +++ 7 files changed, 384 insertions(+), 93 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 4e417679ca663..be5036e82e4b2 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -17,12 +17,17 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator( // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 - // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() - // call. + // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call + private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer] + + // The blocking requestMethod called by tasks to sync up for this stage attempt + private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER + // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. 
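[editorial sketch] The json4s imports and the per-partition allGatherMessages buffer added above exist so the coordinator can encode all gathered messages as one JSON array that every requester receives, and so the task side can decode it back. A minimal standalone round-trip sketch of that encoding; the object name AllGatherJsonRoundTrip is illustrative, not part of the patch:

{{{
import scala.collection.mutable.ArrayBuffer

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}

object AllGatherJsonRoundTrip {
  def main(args: Array[String]): Unit = {
    // Coordinator side: one message per task, indexed by partitionId.
    val allGatherMessages = ArrayBuffer("msg-0", "msg-1", "msg-2", "msg-3")
    // JsonDSL lifts the buffer to a JArray of JStrings.
    val json: String = compact(render(allGatherMessages))
    // json == """["msg-0","msg-1","msg-2","msg-3"]"""

    // Task side: decode the reply back into the per-partition messages.
    implicit val formats: DefaultFormats.type = DefaultFormats
    val decoded = ArrayBuffer(parse(json).extract[Array[String]]: _*)
    assert(decoded == allGatherMessages)
  }
}
}}}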
- def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + def handleRequest( + requester: RpcCallContext, + request: RequestToSync + ): Unit = synchronized { val taskId = request.taskAttemptId val epoch = request.barrierEpoch + val requestMethod = request.requestMethod + val partitionId = request.partitionId + val allGatherMessage = request match { + case ag: AllGatherRequestToSync => ag.allGatherMessage + case _ => "" + } + + if (requesters.size == 0) { + requestMethodToSync = requestMethod + } + + if (requestMethodToSync != requestMethod) { + requesters.foreach( + _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " + + s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " + + s"the current synchronized requestMethod `$requestMethodToSync`" + )) + ) + cleanupBarrierStage(barrierId) + } // Require the number of tasks is correctly set from the BarrierTaskContext. require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator( } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester + allGatherMessages(partitionId) = allGatherMessage logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + s"$taskId, current progress: ${requesters.size}/$numTasks.") if (maybeFinishAllRequesters(requesters, numTasks)) { @@ -173,7 +207,13 @@ private[spark] class BarrierCoordinator( requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { if (requesters.size == numTasks) { - requesters.foreach(_.reply(())) + requestMethodToSync match { + case RequestMethod.BARRIER => + requesters.foreach(_.reply("")) + case RequestMethod.ALL_GATHER => + val json: String = compact(render(allGatherMessages)) + requesters.foreach(_.reply(json)) + } true } else { false @@ -199,11 +239,11 @@ private[spark] class BarrierCoordinator( } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => + case request: RequestToSync => // Get or init the ContextBarrierState correspond to the stage attempt. - val barrierId = ContextBarrierId(stageId, stageAttemptId) + val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId) states.computeIfAbsent(barrierId, - (key: ContextBarrierId) => new ContextBarrierState(key, numTasks)) + (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks)) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -216,6 +256,16 @@ private[spark] class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage { + def numTasks: Int + def stageId: Int + def stageAttemptId: Int + def taskAttemptId: Long + def barrierEpoch: Int + def partitionId: Int + def requestMethod: RequestMethod.Value +} + /** * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is * identified by stageId + stageAttemptId + barrierEpoch. @@ -224,11 +274,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. 
+ * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator */ -private[spark] case class RequestToSync( - numTasks: Int, - stageId: Int, - stageAttemptId: Int, - taskAttemptId: Long, - barrierEpoch: Int) extends BarrierCoordinatorMessage +private[spark] case class BarrierRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value +) extends RequestToSync + +/** + * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls + * @param partitionId ID of the current partition the task is assigned to + * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator + * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER + */ +private[spark] case class AllGatherRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String +) extends RequestToSync + +private[spark] object RequestMethod extends Enumeration { + val BARRIER, ALL_GATHER = Value +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 3d369802f3023..2263538a11676 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,11 +17,19 @@ package org.apache.spark +import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Properties, Timer, TimerTask} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.concurrent.TimeoutException import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.json4s.DefaultFormats +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.parse import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] ( // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size - /** - * :: Experimental :: - * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to - * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same - * stage have reached this routine. - * - * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all - * possible code branches. Otherwise, you may get the job hanging or a SparkException after - * timeout. Some examples of '''misuses''' are listed below: - * 1. 
Only call barrier() function on a subset of all the tasks in the same barrier stage, it - * shall lead to timeout of the function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * if (context.partitionId() == 0) { - * // Do nothing. - * } else { - * context.barrier() - * } - * iter - * } - * }}} - * - * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the - * second function call. - * {{{ - * rdd.barrier().mapPartitions { iter => - * val context = BarrierTaskContext.get() - * try { - * // Do something that might throw an Exception. - * doSomething() - * context.barrier() - * } catch { - * case e: Exception => logWarning("...", e) - * } - * context.barrier() - * iter - * } - * }}} - */ - @Experimental - @Since("2.4.0") - def barrier(): Unit = { + private def getRequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptNumber: Int, + taskAttemptId: Long, + barrierEpoch: Int, + partitionId: Int, + requestMethod: RequestMethod.Value, + allGatherMessage: String + ): RequestToSync = { + requestMethod match { + case RequestMethod.BARRIER => + BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod) + case RequestMethod.ALL_GATHER => + AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch, partitionId, requestMethod, allGatherMessage) + } + } + + private def runBarrier( + requestMethod: RequestMethod.Value, + allGatherMessage: String = "" + ): String = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") logTrace("Current callSite: " + Utils.getCallSite()) @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] ( // Log the update of global sync every 60 seconds. timer.schedule(timerTask, 60000, 60000) + var json: String = "" + try { - val abortableRpcFuture = barrierCoordinator.askAbortable[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch), + val abortableRpcFuture = barrierCoordinator.askAbortable[String]( + message = getRequestToSync(numTasks, stageId, stageAttemptNumber, + taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(365.days, "barrierTimeout")) @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] ( while (!abortableRpcFuture.toFuture.isCompleted) { // wait RPC future for at most 1 second try { - ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) + json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second) } catch { case _: TimeoutException | _: InterruptedException => // If `TimeoutException` thrown, waiting RPC future reach 1 second. @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] ( timerTask.cancel() timer.purge() } + json + } + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. 
Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of '''misuses''' are listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { iter => + * val context = BarrierTaskContext.get() + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit = { + runBarrier(RequestMethod.BARRIER) + () + } + + /** + * :: Experimental :: + * Blocks until all tasks in the same stage have reached this routine. Each task passes in + * a message and returns with a list of all the messages passed in by each of those tasks. + * + * CAUTION! The allGather method requires the same precautions as the barrier method + * + * The message is type String rather than Array[Byte] because it is more convenient for + * the user at the cost of worse performance. + */ + @Experimental + @Since("3.0.0") + def allGather(message: String): ArrayBuffer[String] = { + val json = runBarrier(RequestMethod.ALL_GATHER, message) + val jsonArray = parse(json) + implicit val formats = DefaultFormats + ArrayBuffer(jsonArray.extract[Array[String]]: _*) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 658e0d593a167..fa8bf0fc06358 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,8 +24,13 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -238,13 +243,18 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( sock.setSoTimeout(10000) authHelper.authClient(sock) val input = new DataInputStream(sock.getInputStream()) - input.readInt() match { + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(0) + requestMethod match { case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - // The barrier() function may wait infinitely, socket shall not timeout - // before the function finishes. 
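[editorial sketch] For reference, a driver-side sketch of how the new allGather API composes with RDDBarrier.mapPartitions. It assumes a running SparkContext `sc` with at least four executor slots so all barrier tasks can be scheduled at once; the helper name shareHostNames is illustrative only:

{{{
import org.apache.spark.{BarrierTaskContext, SparkContext}

def shareHostNames(sc: SparkContext): Array[String] = {
  val rdd = sc.makeRDD(1 to 8, numSlices = 4)
  rdd.barrier().mapPartitions { iter =>
    val context = BarrierTaskContext.get()
    // Every task contributes one message; each gets back the full list,
    // ordered by partitionId.
    val hosts = context.allGather(java.net.InetAddress.getLocalHost.getHostName)
    Iterator.single(hosts.mkString(","))
  }.collect()
}
}}}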
- sock.setSoTimeout(0) - barrierAndServe(sock) - + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) case _ => val out = new DataOutputStream(new BufferedOutputStream( sock.getOutputStream)) @@ -395,15 +405,31 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } /** - * Gateway to call BarrierTaskContext.barrier(). + * Gateway to call BarrierTaskContext methods. */ - def barrierAndServe(sock: Socket): Unit = { - require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.") - + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." + ) val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { - context.asInstanceOf[BarrierTaskContext].barrier() - writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out) + var result: String = "" + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather( + message + ) + result = compact(render(JArray( + messages.map( + (message) => JString(message) + ).toList + ))) + } + writeUTF(result, out) } catch { case e: SparkException => writeUTF(e.getMessage, out) @@ -638,6 +664,7 @@ private[spark] object SpecialLengths { private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_FUNCTION = 1 + val ALL_GATHER_FUNCTION = 2 val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index fc8ac38479932..0e6b746c7c230 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.File +import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ @@ -52,6 +53,79 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { assert(times.max - times.min <= 1000) } + test("share messages with allGather() call") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. 
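[editorial sketch] The dispatch above relies on a simple framing convention between the Python worker and this server thread: a 4-byte function code, followed (for ALL_GATHER_FUNCTION) by a length-prefixed UTF-8 payload. A self-contained sketch of both ends of that framing using plain JDK streams; FramingSketch is an illustrative name:

{{{
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets.UTF_8

object FramingSketch {
  val ALL_GATHER_FUNCTION = 2 // mirrors BarrierTaskContextMessageProtocol

  def main(args: Array[String]): Unit = {
    // Worker side: what pyspark's write_int plus write_with_length produce.
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    val payload = "partition-0".getBytes(UTF_8)
    out.writeInt(ALL_GATHER_FUNCTION)
    out.writeInt(payload.length)
    out.write(payload)
    out.flush()

    // Server side: the reads performed by the dispatch loop above.
    val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    val requestMethod = in.readInt()
    val message = new Array[Byte](in.readInt())
    in.readFully(message)
    assert(requestMethod == ALL_GATHER_FUNCTION)
    assert(new String(message, UTF_8) == "partition-0")
  }
}
}}}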
+ Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message: String = context.partitionId().toString + val messages: ArrayBuffer[String] = context.allGather(message) + messages.toList.iterator + } + // Take a sorted list of all the partitionId messages + val messages = rdd2.collect().head + // All the task partitionIds are shared + for((x, i) <- messages.view.zipWithIndex) assert(x.toString == i.toString) + } + + test("throw exception if we attempt to synchronize with different blocking calls") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + val partitionId = context.partitionId + if (partitionId == 0) { + context.barrier() + } else { + context.allGather(partitionId.toString) + } + Seq(null).iterator + } + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("does not match the current synchronized requestMethod")) + } + + test("successively sync with allGather and barrier") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { it => + val context = BarrierTaskContext.get() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + // Pass partitionId message in + val message = context.partitionId().toString + val messages = context.allGather(message) + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 1000) + + // All the tasks shall finish the second round of global sync within a short time slot. 
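[editorial sketch] The first test above can assert that messages come back in partition order because the coordinator writes each message at its partitionId index rather than in arrival order. A standalone sketch of that indexing; OrderingSketch and the arrival order are illustrative:

{{{
import scala.collection.mutable.ArrayBuffer

object OrderingSketch {
  def main(args: Array[String]): Unit = {
    val numTasks = 4
    val allGatherMessages = new Array[String](numTasks).to[ArrayBuffer]
    // Requests may arrive in any order...
    for ((partitionId, msg) <- Seq(2 -> "c", 0 -> "a", 3 -> "d", 1 -> "b")) {
      allGatherMessages(partitionId) = msg
    }
    // ...but every requester receives the same partition-ordered view.
    assert(allGatherMessages == ArrayBuffer("a", "b", "c", "d"))
  }
}
}}}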
+ val times2 = times.map(_._2) + assert(times2.max - times2.min <= 1000) + } + test("support multiple barrier() call within a single task") { initLocalClusterSparkContext() val rdd = sc.makeRDD(1 to 10, 4) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 65ffa228eddec..a48437473d905 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -492,7 +492,10 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol"), // [SPARK-29543][SS][UI] Init structured streaming ui - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this") + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this"), + + // [SPARK-30667][CORE] Add allGather method to BarrierTaskContext + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.RequestToSync") ) // Exclude rules for 2.4.x diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index d648f63338514..3a9afbe985142 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -16,9 +16,10 @@ # from __future__ import print_function +import json from pyspark.java_gateway import local_connect_and_auth -from pyspark.serializers import write_int, UTF8Deserializer +from pyspark.serializers import write_int, write_with_length, UTF8Deserializer class TaskContext(object): @@ -62,7 +63,6 @@ def get(cls): """ Return the currently active TaskContext. This can be called inside of user functions to access contextual information about running tasks. - .. note:: Must be called on the worker, not the driver. Returns None if not initialized. """ return cls._taskContext @@ -107,18 +107,28 @@ def resources(self): BARRIER_FUNCTION = 1 +ALL_GATHER_FUNCTION = 2 -def _load_from_socket(port, auth_secret): +def _load_from_socket(port, auth_secret, function, all_gather_message=None): """ Load data from a given socket, this is a blocking method thus only return when the socket connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) - # The barrier() call may block forever, so no timeout + + # The call may block forever, so no timeout sock.settimeout(None) - # Make a barrier() function call. - write_int(BARRIER_FUNCTION, sockfile) + + if function == BARRIER_FUNCTION: + # Make a barrier() function call. + write_int(function, sockfile) + elif function == ALL_GATHER_FUNCTION: + # Make a all_gather() function call. + write_int(function, sockfile) + write_with_length(all_gather_message.encode("utf-8"), sockfile) + else: + raise ValueError("Unrecognized function type") sockfile.flush() # Collect result. @@ -135,10 +145,8 @@ class BarrierTaskContext(TaskContext): """ .. note:: Experimental - A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage. Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task. - .. versionadded:: 2.4.0 """ @@ -160,11 +168,9 @@ def _getOrCreate(cls): def get(cls): """ .. note:: Experimental - Return the currently active :class:`BarrierTaskContext`. This can be called inside of user functions to access contextual information about running tasks. - .. note:: Must be called on the worker, not the driver. Returns None if not initialized. An Exception will raise if it is not in a barrier stage. 
""" @@ -184,30 +190,49 @@ def _initialize(cls, port, secret): def barrier(self): """ .. note:: Experimental - Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks in the same stage have reached this routine. - .. warning:: In a barrier stage, each task much have the same number of `barrier()` calls, in all possible code branches. Otherwise, you may get the job hanging or a SparkException after timeout. - .. versionadded:: 2.4.0 """ if self._port is None or self._secret is None: raise Exception("Not supported to call barrier() before initialize " + "BarrierTaskContext.") else: - _load_from_socket(self._port, self._secret) + _load_from_socket(self._port, self._secret, BARRIER_FUNCTION) - def getTaskInfos(self): + def allGather(self, message=""): """ .. note:: Experimental + This function blocks until all tasks in the same stage have reached this routine. + Each task passes in a message and returns with a list of all the messages passed in + by each of those tasks. + .. warning:: In a barrier stage, each task much have the same number of `allGather()` + calls, in all possible code branches. + Otherwise, you may get the job hanging or a SparkException after timeout. + """ + if not isinstance(message, str): + raise ValueError("Argument `message` must be of type `str`") + elif self._port is None or self._secret is None: + raise Exception("Not supported to call barrier() before initialize " + + "BarrierTaskContext.") + else: + gathered_items = _load_from_socket( + self._port, + self._secret, + ALL_GATHER_FUNCTION, + message, + ) + return [e for e in json.loads(gathered_items)] + def getTaskInfos(self): + """ + .. note:: Experimental Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage, ordered by partition ID. - .. versionadded:: 2.4.0 """ if self._port is None or self._secret is None: @@ -221,11 +246,8 @@ def getTaskInfos(self): class BarrierTaskInfo(object): """ .. note:: Experimental - Carries all task infos of a barrier task. - :var address: The IPv4 address (host:port) of the executor that the barrier task is running on - .. versionadded:: 2.4.0 """ diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 68cfe814762e0..399f5fcac7bf8 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -135,6 +135,29 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) +<<<<<<< HEAD +======= + def test_all_gather(self): + """ + Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks + within a stage and passes messages properly. + """ + rdd = self.sc.parallelize(range(10), 4) + + def f(iterator): + yield sum(iterator) + + def context_barrier(x): + tc = BarrierTaskContext.get() + time.sleep(random.randint(1, 10)) + out = tc.allGather(str(tc.partitionId())) + pids = [int(e) for e in out] + return pids + + pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] + self.assertEqual(pids, [0, 1, 2, 3]) + +>>>>>>> 7f928deb9b... 
Add all gather method to BarrierTaskContext def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From 3f8e9bf1a2c2c5b4e5cd093b2e7ca8212a95d29f Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Wed, 19 Feb 2020 17:28:01 -0800 Subject: [PATCH 2/7] fix merge --- python/pyspark/tests/test_taskcontext.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py index 399f5fcac7bf8..752430a958391 100644 --- a/python/pyspark/tests/test_taskcontext.py +++ b/python/pyspark/tests/test_taskcontext.py @@ -135,8 +135,6 @@ def context_barrier(x): times = rdd.barrier().mapPartitions(f).map(context_barrier).collect() self.assertTrue(max(times) - min(times) < 1) -<<<<<<< HEAD -======= def test_all_gather(self): """ Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks @@ -157,7 +155,6 @@ def context_barrier(x): pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0] self.assertEqual(pids, [0, 1, 2, 3]) ->>>>>>> 7f928deb9b... Add all gather method to BarrierTaskContext def test_barrier_infos(self): """ Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the From 711803a150841cf24505f62b855f6f9fb721a081 Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Wed, 19 Feb 2020 17:32:47 -0800 Subject: [PATCH 3/7] fix spaces --- python/pyspark/taskcontext.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 3a9afbe985142..2ed180dd6694c 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -63,6 +63,7 @@ def get(cls): """ Return the currently active TaskContext. This can be called inside of user functions to access contextual information about running tasks. + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. """ return cls._taskContext @@ -145,8 +146,10 @@ class BarrierTaskContext(TaskContext): """ .. note:: Experimental + A :class:`TaskContext` with extra contextual info and tooling for tasks in a barrier stage. Use :func:`BarrierTaskContext.get` to obtain the barrier context for a running barrier task. + .. versionadded:: 2.4.0 """ @@ -168,9 +171,11 @@ def _getOrCreate(cls): def get(cls): """ .. note:: Experimental + Return the currently active :class:`BarrierTaskContext`. This can be called inside of user functions to access contextual information about running tasks. + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. An Exception will raise if it is not in a barrier stage. """ @@ -190,12 +195,15 @@ def _initialize(cls, port, secret): def barrier(self): """ .. note:: Experimental + Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to `MPI_Barrier` function in MPI, this function blocks until all tasks in the same stage have reached this routine. + .. warning:: In a barrier stage, each task much have the same number of `barrier()` calls, in all possible code branches. Otherwise, you may get the job hanging or a SparkException after timeout. + .. versionadded:: 2.4.0 """ if self._port is None or self._secret is None: @@ -207,12 +215,16 @@ def barrier(self): def allGather(self, message=""): """ .. note:: Experimental + This function blocks until all tasks in the same stage have reached this routine. 
Each task passes in a message and returns with a list of all the messages passed in by each of those tasks. + .. warning:: In a barrier stage, each task much have the same number of `allGather()` calls, in all possible code branches. Otherwise, you may get the job hanging or a SparkException after timeout. + + .. versionadded:: 3.0.0 """ if not isinstance(message, str): raise ValueError("Argument `message` must be of type `str`") @@ -233,6 +245,7 @@ def getTaskInfos(self): .. note:: Experimental Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage, ordered by partition ID. + .. versionadded:: 2.4.0 """ if self._port is None or self._secret is None: @@ -246,8 +259,11 @@ def getTaskInfos(self): class BarrierTaskInfo(object): """ .. note:: Experimental + Carries all task infos of a barrier task. + :var address: The IPv4 address (host:port) of the executor that the barrier task is running on + .. versionadded:: 2.4.0 """ From 68176439601bc320e4b2b604f63ffea53e12dde7 Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Wed, 19 Feb 2020 17:33:35 -0800 Subject: [PATCH 4/7] fix spaces --- python/pyspark/taskcontext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 2ed180dd6694c..3fe539e89b72a 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -63,7 +63,7 @@ def get(cls): """ Return the currently active TaskContext. This can be called inside of user functions to access contextual information about running tasks. - + .. note:: Must be called on the worker, not the driver. Returns None if not initialized. """ return cls._taskContext From c1d1b0eccccf60b54b4268dee3505d98fa4cb59f Mon Sep 17 00:00:00 2001 From: sarthfrey Date: Thu, 20 Feb 2020 00:50:45 -0800 Subject: [PATCH 5/7] fix indent --- python/pyspark/taskcontext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 3fe539e89b72a..e4fc64b732ba7 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -243,6 +243,7 @@ def allGather(self, message=""): def getTaskInfos(self): """ .. note:: Experimental + Returns :class:`BarrierTaskInfo` for all tasks in this barrier stage, ordered by partition ID. From 3f1f7091cabbe536140fff1820247ce3a1213f70 Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Thu, 20 Feb 2020 15:45:49 -0800 Subject: [PATCH 6/7] fix flaky test --- core/src/main/scala/org/apache/spark/BarrierCoordinator.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index be5036e82e4b2..bc46170ec5c1b 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -164,7 +164,6 @@ private[spark] class BarrierCoordinator( s"the current synchronized requestMethod `$requestMethodToSync`" )) ) - cleanupBarrierStage(barrierId) } // Require the number of tasks is correctly set from the BarrierTaskContext. 
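[editorial sketch] Patch 6 above deflakes the suite by dropping the cleanupBarrierStage call that follows the mismatch failure; patch 7 below restores it and instead relaxes the assertion, since a failed barrier stage can also surface as tasks that were "not properly killed". The misuse both patches guard against, condensed from the suite into a sketch; mixedSyncFails is an illustrative helper and assumes a SparkContext `sc` on a local-cluster master:

{{{
import org.apache.spark.{BarrierTaskContext, SparkContext, SparkException}

def mixedSyncFails(sc: SparkContext): String = {
  val rdd = sc.makeRDD(1 to 8, 4).barrier().mapPartitions { iter =>
    val context = BarrierTaskContext.get()
    if (context.partitionId() == 0) {
      context.barrier()      // one task picks barrier()...
    } else {
      context.allGather("x") // ...the rest pick allGather()
    }
    iter
  }
  try {
    rdd.collect()
    "unexpected success"
  } catch {
    // The coordinator fails every pending requester once it sees the mismatch.
    case e: SparkException => e.getMessage
  }
}
}}}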
From 7c259acd3e447450a18bcdeea47d14ac8514b1bc Mon Sep 17 00:00:00 2001 From: sarthfrey-db Date: Thu, 20 Feb 2020 16:24:59 -0800 Subject: [PATCH 7/7] revert and relax test assertion --- .../src/main/scala/org/apache/spark/BarrierCoordinator.scala | 1 + .../org/apache/spark/scheduler/BarrierTaskContextSuite.scala | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index bc46170ec5c1b..be5036e82e4b2 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -164,6 +164,7 @@ private[spark] class BarrierCoordinator( s"the current synchronized requestMethod `$requestMethodToSync`" )) ) + cleanupBarrierStage(barrierId) } // Require the number of tasks is correctly set from the BarrierTaskContext. diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 0e6b746c7c230..33594c0a50d14 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -93,7 +93,10 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val error = intercept[SparkException] { rdd2.collect() }.getMessage - assert(error.contains("does not match the current synchronized requestMethod")) + assert( + error.contains("does not match the current synchronized requestMethod") || + error.contains("not properly killed") + ) } test("successively sync with allGather and barrier") {
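[editorial sketch] Taken together, the series leaves barrier() and allGather() free to interleave within one stage, provided every task issues the same sequence of blocking calls. A closing usage sketch under that contract; syncTwice is an illustrative name and assumes a SparkContext `sc` with four free slots:

{{{
import org.apache.spark.{BarrierTaskContext, SparkContext}

def syncTwice(sc: SparkContext): Array[String] = {
  sc.makeRDD(1 to 8, 4).barrier().mapPartitions { iter =>
    val context = BarrierTaskContext.get()
    context.barrier() // round 1: plain global sync
    // round 2: gather one message per task, ordered by partitionId
    val msgs = context.allGather(context.partitionId().toString)
    Iterator.single(msgs.mkString("|"))
  }.collect()
}
}}}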