Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
Expand Down Expand Up @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator(
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0

// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
// call.
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

// A buffer of allGather messages, indexed by partition ID, for barrier tasks that have made a
// blocking runBarrier() call
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]

// The blocking requestMethod called by tasks to sync up for this stage attempt
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER

// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null

Expand Down Expand Up @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator(

// Process the global sync request. The barrier() call succeeds if enough requests are collected
// within the configured time; otherwise all the pending requests fail.
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
def handleRequest(
requester: RpcCallContext,
request: RequestToSync
): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
val requestMethod = request.requestMethod
val partitionId = request.partitionId
val allGatherMessage = request match {
case ag: AllGatherRequestToSync => ag.allGatherMessage
case _ => ""
}

if (requesters.size == 0) {
requestMethodToSync = requestMethod
}

if (requestMethodToSync != requestMethod) {
requesters.foreach(
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethodToSync`"
))
)
cleanupBarrierStage(barrierId)
}

// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
Expand All @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
allGatherMessages(partitionId) = allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
Expand All @@ -173,7 +207,13 @@ private[spark] class BarrierCoordinator(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
requesters.foreach(_.reply(()))
requestMethodToSync match {
case RequestMethod.BARRIER =>
requesters.foreach(_.reply(""))
case RequestMethod.ALL_GATHER =>
val json: String = compact(render(allGatherMessages))
requesters.foreach(_.reply(json))
}
true
} else {
false
Expand All @@ -199,11 +239,11 @@ private[spark] class BarrierCoordinator(
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
case request: RequestToSync =>
// Get or init the ContextBarrierState corresponding to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
val barrierState = states.get(barrierId)

barrierState.handleRequest(context, request)
Expand All @@ -216,6 +256,16 @@ private[spark] class BarrierCoordinator(

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
* Common shape of the global sync requests that BarrierTaskContext sends to BarrierCoordinator.
* The coordinator reads these members directly from whichever concrete request it receives.
*/
private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
// Number of tasks reported by the sender; the coordinator requires it to match its own count.
def numTasks: Int
// Together with stageAttemptId, identifies the barrier state the request belongs to.
def stageId: Int
def stageAttemptId: Int
// ID of the task that issued the request; used by the coordinator when logging progress.
def taskAttemptId: Long
// Epoch of the sync call this request belongs to.
def barrierEpoch: Int
// Partition of the sending task; the coordinator uses it as the index for the task's message.
def partitionId: Int
// Which BarrierTaskContext method issued the request; all requests of one sync must agree.
def requestMethod: RequestMethod.Value
}

/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
Expand All @@ -224,11 +274,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
* @param barrierEpoch ID of the `barrier()` call; a task may contain multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int) extends BarrierCoordinatorMessage
/**
* The [[RequestToSync]] message built for a `barrier()` call. It carries only the common sync
* fields; no message payload is attached.
*/
private[spark] case class BarrierRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value
) extends RequestToSync

/**
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the global sync call; a task may contain multiple such calls
* @param partitionId ID of the current partition the task is assigned to; the coordinator
*                    stores the task's message at this index
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
*/
private[spark] case class AllGatherRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
) extends RequestToSync

/**
* Enumerates the BarrierTaskContext methods that can trigger a global sync: `BARRIER` is sent by
* `barrier()` and `ALL_GATHER` by `allGather()`.
*/
private[spark] object RequestMethod extends Enumeration {
val BARRIER, ALL_GATHER = Value
}
153 changes: 106 additions & 47 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
Expand Down Expand Up @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
/**
 * Builds the sync request message that corresponds to `requestMethod`.
 *
 * A `BARRIER` request carries only the common sync fields, while an `ALL_GATHER` request
 * additionally carries `allGatherMessage`. For `BARRIER` the message argument is ignored.
 *
 * @param stageAttemptNumber passed to the message as its `stageAttemptId`
 * @return the request to send to the barrier coordinator
 */
private def getRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptNumber: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value,
    allGatherMessage: String): RequestToSync = requestMethod match {
  case RequestMethod.ALL_GATHER =>
    AllGatherRequestToSync(
      numTasks = numTasks,
      stageId = stageId,
      stageAttemptId = stageAttemptNumber,
      taskAttemptId = taskAttemptId,
      barrierEpoch = barrierEpoch,
      partitionId = partitionId,
      requestMethod = requestMethod,
      allGatherMessage = allGatherMessage)
  case RequestMethod.BARRIER =>
    BarrierRequestToSync(
      numTasks = numTasks,
      stageId = stageId,
      stageAttemptId = stageAttemptNumber,
      taskAttemptId = taskAttemptId,
      barrierEpoch = barrierEpoch,
      partitionId = partitionId,
      requestMethod = requestMethod)
}

private def runBarrier(
requestMethod: RequestMethod.Value,
allGatherMessage: String = ""
): String = {

logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
Expand All @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] (
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)

var json: String = ""

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
Expand All @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] (
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
Expand Down Expand Up @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] (
timerTask.cancel()
timer.purge()
}
json
}

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
// runBarrier returns the coordinator's String reply; barrier() has nothing to hand back to the
// caller, so the reply is dropped and Unit is returned explicitly.
runBarrier(RequestMethod.BARRIER)
()
}

/**
* :: Experimental ::
* Blocks until all tasks in the same stage have reached this routine. Each task passes in
* a message and returns with a list of all the messages passed in by each of those tasks.
*
* CAUTION! The allGather method requires the same precautions as the barrier method
*
* The message is type String rather than Array[Byte] because it is more convenient for
* the user at the cost of worse performance.
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): ArrayBuffer[String] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity, why return an ArrayBuffer[String] instead of an Array[String] here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

friendly ping @jiangxb1987 @sarthfrey

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair point; why not just Seq?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't have a particular reason in mind for ArrayBuffer[String] over Array[String], @zhengruifeng do you think the latter is preferable here, and if so, why? The returned collection is indexed and sorted by partition ID so I preferred those over Seq which is vague about whether it is naturally indexed or linear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK sure IndexedSeq. or Array is fine. Just something immutable

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, will submit a PR.

val json = runBarrier(RequestMethod.ALL_GATHER, message)
val jsonArray = parse(json)
implicit val formats = DefaultFormats
ArrayBuffer(jsonArray.extract[Array[String]]: _*)
}

/**
Expand Down
Loading