Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7370857
Add all gather method to BarrierTaskContext
sarthfrey Jan 30, 2020
fec40fe
change method to allGather
sarthfrey Jan 30, 2020
390fb1f
fix docstring
sarthfrey Jan 30, 2020
a1229c9
fix test
sarthfrey Jan 30, 2020
2b8a199
test2
sarthfrey Jan 30, 2020
d63eab3
test3
sarthfrey Jan 30, 2020
ebd102c
test4
sarthfrey Jan 30, 2020
7fad912
test5
sarthfrey Jan 30, 2020
62b8a30
Change API to send and receive bytes rather than strings
sarthfrey Jan 30, 2020
f17cdd5
doc fix
sarthfrey Jan 30, 2020
d52f0ba
doc fix 2
sarthfrey Jan 30, 2020
f62a1d5
fix test
sarthfrey Jan 30, 2020
ec198f1
fix test 2
sarthfrey Jan 30, 2020
f7bdd8a
fix test 3
sarthfrey Jan 30, 2020
2079548
fix test 4
sarthfrey Jan 30, 2020
aca81bc
fix test 5
sarthfrey Jan 30, 2020
76ea287
fix test 6
sarthfrey Jan 30, 2020
8a4c450
fix test 7
sarthfrey Jan 30, 2020
adfab5d
fix test final
sarthfrey Jan 30, 2020
149e1f3
add python test
sarthfrey Jan 30, 2020
47c514a
fix test final 2
sarthfrey Jan 30, 2020
c047af8
address review round 1
sarthfrey Feb 1, 2020
2fba607
Change allGather API to accept string over bytes
sarthfrey Feb 9, 2020
ed2d7ef
addressed review feedback round 2
sarthfrey Feb 9, 2020
145203a
comments
sarthfrey Feb 9, 2020
0405817
rm trailing whitespace
sarthfrey Feb 9, 2020
dac63cf
address review feedback round 2
sarthfrey Feb 12, 2020
527c9b8
address review round 3
sarthfrey Feb 12, 2020
6368af7
address review feedback round 4
sarthfrey Feb 12, 2020
d2adfd4
address review round 5
sarthfrey Feb 13, 2020
34efff6
address review round 6
sarthfrey Feb 13, 2020
6dc4447
fix test
sarthfrey Feb 13, 2020
fef813d
retrigger build
sarthfrey Feb 13, 2020
74d14f5
retrigger build
sarthfrey Feb 13, 2020
6398066
add mima exclusion rule
sarthfrey Feb 14, 2020
377d8d2
fix semicolon
sarthfrey Feb 14, 2020
ff7f3dd
fix tests
sarthfrey Feb 18, 2020
24adef3
fix python unit test
sarthfrey Feb 19, 2020
d2fffe1
fix python unit test final
sarthfrey Feb 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 97 additions & 14 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
Expand Down Expand Up @@ -99,10 +104,15 @@ private[spark] class BarrierCoordinator(
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0

// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
// call.
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

// An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]

// The blocking requestMethod called by tasks to sync up for this stage attempt
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER

// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null

Expand Down Expand Up @@ -130,9 +140,32 @@ private[spark] class BarrierCoordinator(

// Process the global sync request. The barrier() call succeeds if enough requests are
// collected within the configured time; otherwise all the pending requests are failed.
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
def handleRequest(
requester: RpcCallContext,
request: RequestToSync
): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
val requestMethod = request.requestMethod
val partitionId = request.partitionId
val allGatherMessage = request match {
case ag: AllGatherRequestToSync => ag.allGatherMessage
case _ => ""
}

if (requesters.size == 0) {
requestMethodToSync = requestMethod
}

if (requestMethodToSync != requestMethod) {
requesters.foreach(
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethodToSync`"
))
)
cleanupBarrierStage(barrierId)
}

// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
Expand All @@ -153,6 +186,7 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
allGatherMessages(partitionId) = allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
Expand All @@ -173,7 +207,13 @@ private[spark] class BarrierCoordinator(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
requesters.foreach(_.reply(()))
requestMethodToSync match {
case RequestMethod.BARRIER =>
requesters.foreach(_.reply(""))
case RequestMethod.ALL_GATHER =>
val json: String = compact(render(allGatherMessages))
requesters.foreach(_.reply(json))
}
true
} else {
false
Expand All @@ -199,11 +239,11 @@ private[spark] class BarrierCoordinator(
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
case request: RequestToSync =>
// Get or init the ContextBarrierState correspond to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
val barrierState = states.get(barrierId)

barrierState.handleRequest(context, request)
Expand All @@ -216,6 +256,16 @@ private[spark] class BarrierCoordinator(

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
 * Common shape of a global sync request sent from a BarrierTaskContext to the
 * BarrierCoordinator. Concrete subtypes carry any method-specific payload.
 */
private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
  // Identity of the barrier stage attempt this request belongs to.
  def stageId: Int
  def stageAttemptId: Int
  // Identity of the requesting task.
  def taskAttemptId: Long
  def partitionId: Int
  // How many sync requests the coordinator expects, and which sync round this one is part of.
  def numTasks: Int
  def barrierEpoch: Int
  // Which BarrierTaskContext method (BARRIER / ALL_GATHER) triggered this request.
  def requestMethod: RequestMethod.Value
}

/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
Expand All @@ -224,11 +274,44 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
 * @param barrierEpoch ID of the `barrier()` call, a task may consist of multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int) extends BarrierCoordinatorMessage
// Plain barrier() sync request: carries no payload beyond the common RequestToSync fields.
private[spark] case class BarrierRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value) extends RequestToSync

/**
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
 * @param barrierEpoch ID of the `barrier()` call, a task may consist of multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
*/
// allGather() sync request: additionally carries the message this task contributes.
private[spark] case class AllGatherRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value,
    allGatherMessage: String) extends RequestToSync

/**
 * The BarrierTaskContext method a task used to enter the global sync. Declared one per line so
 * each member's ordinal is explicit in declaration order (BARRIER = 0, ALL_GATHER = 1).
 */
private[spark] object RequestMethod extends Enumeration {
  val BARRIER: RequestMethod.Value = Value
  val ALL_GATHER: RequestMethod.Value = Value
}
153 changes: 106 additions & 47 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,19 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
Expand Down Expand Up @@ -59,49 +67,31 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
/**
 * Builds the RPC message matching the sync method the task invoked.
 *
 * An ALL_GATHER request carries the task's `allGatherMessage` payload; a BARRIER request
 * carries only the common identification fields. The match is intentionally exhaustive over
 * the two RequestMethod values.
 */
private def getRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptNumber: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value,
    allGatherMessage: String): RequestToSync = {
  requestMethod match {
    case RequestMethod.ALL_GATHER =>
      AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
        barrierEpoch, partitionId, requestMethod, allGatherMessage)
    case RequestMethod.BARRIER =>
      BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
        barrierEpoch, partitionId, requestMethod)
  }
}

private def runBarrier(
requestMethod: RequestMethod.Value,
allGatherMessage: String = ""
): String = {

logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
Expand All @@ -118,10 +108,12 @@ class BarrierTaskContext private[spark] (
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)

var json: String = ""

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
Expand All @@ -133,7 +125,7 @@ class BarrierTaskContext private[spark] (
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
Expand Down Expand Up @@ -163,6 +155,73 @@ class BarrierTaskContext private[spark] (
timerTask.cancel()
timer.purge()
}
json
}

/**
 * :: Experimental ::
 * Sets a global barrier and blocks until every task in this barrier stage has reached it,
 * similar to the MPI_Barrier function in MPI.
 *
 * CAUTION! Within a barrier stage, every task must make the same number of barrier() calls,
 * across all possible code branches; otherwise the job may hang, or a SparkException may be
 * thrown after the timeout. Some examples of '''misuses''' are listed below:
 * 1. Calling barrier() in only a subset of the tasks of the stage, which leads to a timeout
 * of the function call.
 * {{{
 *   rdd.barrier().mapPartitions { iter =>
 *       val context = BarrierTaskContext.get()
 *       if (context.partitionId() == 0) {
 *           // Do nothing.
 *       } else {
 *           context.barrier()
 *       }
 *       iter
 *   }
 * }}}
 *
 * 2. Wrapping a barrier() call in a try-catch block, which may lead to a timeout of the
 * second call.
 * {{{
 *   rdd.barrier().mapPartitions { iter =>
 *       val context = BarrierTaskContext.get()
 *       try {
 *           // Do something that might throw an Exception.
 *           doSomething()
 *           context.barrier()
 *       } catch {
 *           case e: Exception => logWarning("...", e)
 *       }
 *       context.barrier()
 *       iter
 *   }
 * }}}
 */
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
  // The coordinator's reply payload is meaningless for a plain barrier; discard it.
  val _ = runBarrier(RequestMethod.BARRIER)
}

/**
 * :: Experimental ::
 * Blocks until every task in the same stage has reached this routine, then returns the
 * messages passed in by all of those tasks (one entry per task).
 *
 * CAUTION! The allGather method requires the same precautions as the barrier method.
 *
 * The message is of type String rather than Array[Byte] because it is more convenient for
 * the user, at the cost of worse performance.
 */
@Experimental
@Since("3.0.0")
def allGather(message: String): ArrayBuffer[String] = {
  implicit val formats = DefaultFormats
  // The coordinator replies with a JSON array holding every task's message.
  val replyJson = runBarrier(RequestMethod.ALL_GATHER, message)
  val gathered = parse(replyJson).extract[Array[String]]
  ArrayBuffer(gathered: _*)
}

/**
Expand Down
Loading