Skip to content
Closed
235 changes: 235 additions & 0 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Consumer, Function}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}

/**
* For each barrier stage attempt, only at most one barrier() call can be active at any time, thus
* we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is
* from.
*/
private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)"
}

/**
* A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
* request is generated by `BarrierTaskContext.barrier()`, and identified by
* stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon
* all the requests for a group of `barrier()` calls are received. If the coordinator is unable to
* collect enough global sync requests within a configured time, fail all the requests and return
* an Exception with timeout message.
*/
private[spark] class BarrierCoordinator(
timeoutInSecs: Long,
listenerBus: LiveListenerBus,
override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {

// TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
// fetch result, we shall fix the issue.
private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")

// Listen to StageCompleted event, clear corresponding ContextBarrierState.
private val listener = new SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
val stageInfo = stageCompleted.stageInfo
val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber)
// Clear ContextBarrierState from a finished stage attempt.
cleanupBarrierStage(barrierId)
}
}

// Record all active stage attempts that make barrier() call(s), and the corresponding internal
// state.
private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]

override def onStart(): Unit = {
super.onStart()
listenerBus.addToStatusQueue(listener)
}

override def onStop(): Unit = {
try {
states.forEachValue(1, clearStateConsumer)
states.clear()
listenerBus.removeListener(listener)
} finally {
super.onStop()
}
}

/**
* Provide the current state of a barrier() call. A state is created when a new stage attempt
* sends out a barrier() call, and recycled on stage completed.
*
* @param barrierId Identifier of the barrier stage that make a barrier() call.
* @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall
* collect `numTasks` requests to succeed.
*/
private class ContextBarrierState(
val barrierId: ContextBarrierId,
val numTasks: Int) {

// There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
// to identify each barrier() call. It shall get increased when a barrier() call succeeds, or
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0

// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
// call.
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null

// Init a TimerTask for a barrier() call.
private def initTimerTask(): Unit = {
timerTask = new TimerTask {
override def run(): Unit = synchronized {
// Timeout current barrier() call, fail all the sync requests.
requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
s"$timeoutInSecs second(s).")))
cleanupBarrierStage(barrierId)
}
}
}

// Cancel the current active TimerTask and release resources.
private def cancelTimerTask(): Unit = {
if (timerTask != null) {
timerTask.cancel()
timerTask = null
}
}

// Process the global sync request. The barrier() call succeed if collected enough requests
// within a configured time, otherwise fail all the pending requests.
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why synchronized? I think we only access ContextBarrierState in the RPC thread, which is single thread.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see, the timer thread also accesses ContextBarrierState, the code in timer thread also need to be synchronized.

val taskId = request.taskAttemptId
val epoch = request.barrierEpoch

// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
s"${request.numTasks} from Task $taskId, previously it was $numTasks.")

// Check whether the epoch from the barrier tasks matches current barrierEpoch.
logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.")
if (epoch != barrierEpoch) {
requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " +
s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
"properly killed."))
} else {
// If this is the first sync message received for a barrier() call, start timer to ensure
// we may timeout for the sync.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We create ContextBarrierState when we receive the first sync message, I think it's more clear to create the timer when creating ContextBarrierState, so that we don't need the if here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is true if you have only one barrier() call in a barrier task. Otherwise, if the first barrier() call succeed, the RPCCallContext array shall get cleared and the barrier epoch increased, but we still reuse the same ContextBarrierState for forthcoming barrier() calls.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i see, makes sense

if (requesters.isEmpty) {
initTimerTask()
timer.schedule(timerTask, timeoutInSecs * 1000)
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
// Finished current barrier() call successfully, clean up ContextBarrierState and
// increase the barrier epoch.
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
s"tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
cancelTimerTask()
}
}
}

// Finish all the blocking barrier sync requests from a stage attempt successfully if we
// have received all the sync requests.
private def maybeFinishAllRequesters(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
requesters.foreach(_.reply(()))
true
} else {
false
}
}

// Cleanup the internal state of a barrier stage attempt.
def clear(): Unit = synchronized {
// The global sync fails so the stage is expected to retry another attempt, all sync
// messages come from current stage attempt shall fail.
barrierEpoch = -1
Copy link
Member

@gatorsmile gatorsmile Aug 5, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, you expect we will issue the following exception after we change it to -1?

new SparkException(s"The request to sync of $barrierId with " +
          s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
          "properly killed.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea

requesters.clear()
cancelTimerTask()
}
}

// Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt.
private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = {
val barrierState = states.remove(barrierId)
if (barrierState != null) {
barrierState.clear()
}
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
// Get or init the ContextBarrierState correspond to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] {
override def apply(key: ContextBarrierId): ContextBarrierState =
new ContextBarrierState(key, numTasks)
})
val barrierState = states.get(barrierId)

barrierState.handleRequest(context, request)
}

private val clearStateConsumer = new Consumer[ContextBarrierState] {
override def accept(state: ContextBarrierState) = state.clear()
}
}

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
*/
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int) extends BarrierCoordinatorMessage
62 changes: 60 additions & 2 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@

package org.apache.spark

import java.util.Properties
import java.util.{Properties, Timer, TimerTask}

import scala.concurrent.duration._
import scala.language.postfixOps

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
import org.apache.spark.util.{RpcUtils, Utils}

/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
class BarrierTaskContext(
Expand All @@ -39,6 +44,22 @@ class BarrierTaskContext(
extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
taskMemoryManager, localProperties, metricsSystem, taskMetrics) {

// Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls.
private val barrierCoordinator: RpcEndpointRef = {
val env = SparkEnv.get
RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv)
}

private val timer = new Timer("Barrier task timer for barrier() calls.")

// Local barrierEpoch that identify a barrier() call from current task, it shall be identical
// with the driver side epoch.
private var barrierEpoch = 0

// Number of tasks of the current barrier stage, a barrier() call must collect enough requests
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be a def.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If change it to a def then we have to call getTaskInfos() every time, the current lazy val shall only call getTaskInfos() once.


/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
Expand Down Expand Up @@ -80,7 +101,44 @@ class BarrierTaskContext(
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
// TODO SPARK-24817 implement global barrier.
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())

val startTime = System.currentTimeMillis()
val timerTask = new TimerTask {
override def run(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
s"under the global sync since $startTime, has been waiting for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
s"is $barrierEpoch.")
}
}
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)

try {
barrierCoordinator.askSync[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout"))
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " +
s"$barrierEpoch.")
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
"to perform global sync, waited for " +
s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
s"is $barrierEpoch.")
throw e
} finally {
timerTask.cancel()
}
}

/**
Expand Down
12 changes: 6 additions & 6 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,12 @@ class SparkContext(config: SparkConf) extends Logging {
Utils.tryLogNonFatalError {
_executorAllocationManager.foreach(_.stop())
}
if (_dagScheduler != null) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to fix #21898 (comment) , previously LiveListenerBus was stopped before we stop DAGScheduler.

Utils.tryLogNonFatalError {
_dagScheduler.stop()
}
_dagScheduler = null
}
if (_listenerBusStarted) {
Utils.tryLogNonFatalError {
listenerBus.stop()
Expand All @@ -1939,12 +1945,6 @@ class SparkContext(config: SparkConf) extends Logging {
Utils.tryLogNonFatalError {
_eventLogger.foreach(_.stop())
}
if (_dagScheduler != null) {
Utils.tryLogNonFatalError {
_dagScheduler.stop()
}
_dagScheduler = null
}
if (env != null && _heartbeatReceiver != null) {
Utils.tryLogNonFatalError {
env.rpcEnv.stop(_heartbeatReceiver)
Expand Down
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -567,4 +567,14 @@ package object config {
.intConf
.checkValue(v => v > 0, "The value should be a positive integer.")
.createWithDefault(2000)

private[spark] val BARRIER_SYNC_TIMEOUT =
ConfigBuilder("spark.barrier.sync.timeout")
.doc("The timeout in seconds for each barrier() call from a barrier task. If the " +
"coordinator didn't receive all the sync messages from barrier tasks within the " +
"configed time, throw a SparkException to fail all the tasks. The default value is set " +
"to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.")
.timeConf(TimeUnit.SECONDS)
.checkValue(v => v > 0, "The value should be a positive time value.")
.createWithDefaultString("365d")
}
Loading