-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24817][Core] Implement BarrierTaskContext.barrier() #21898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
dbde1e4
b690f67
2696f18
330a26b
7e413b4
da44790
3ced829
2f23e44
67dcf17
e29e3b6
33a8926
16ee90e
53aa316
da52db2
8e888b5
a8fa8db
ab49fed
027ca71
1f71e65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,235 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark | ||
|
|
||
| import java.util.{Timer, TimerTask} | ||
| import java.util.concurrent.ConcurrentHashMap | ||
| import java.util.function.{Consumer, Function} | ||
|
|
||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} | ||
| import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} | ||
|
|
||
| /** | ||
| * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus | ||
| * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is | ||
| * from. | ||
| */ | ||
| private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { | ||
|   // Rendered wherever a barrier id is interpolated into a log or error message. | ||
|   override def toString: String = { | ||
|     val attemptPart = s"(Attempt $stageAttemptId)" | ||
|     s"Stage $stageId $attemptPart" | ||
|   } | ||
| } | ||
|
|
||
| /** | ||
| * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync | ||
| * request is generated by `BarrierTaskContext.barrier()`, and identified by | ||
| * stageId + stageAttemptId + barrierEpoch. The coordinator replies to all the blocked global sync | ||
| * requests once every request for a group of `barrier()` calls has been received. If it is unable | ||
| * to collect enough global sync requests within the configured time, it fails all the pending | ||
| * requests with an exception carrying a timeout message. | ||
| */ | ||
| private[spark] class BarrierCoordinator( | ||
| timeoutInSecs: Long, | ||
| listenerBus: LiveListenerBus, | ||
| override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { | ||
|
|
||
| // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to | ||
| // fetch result, we shall fix the issue. | ||
| private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") | ||
|
|
||
| // Drop the ContextBarrierState of a stage attempt as soon as its StageCompleted event arrives. | ||
| private val listener = new SparkListener { | ||
|   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { | ||
|     val finished = stageCompleted.stageInfo | ||
|     // The (stageId, attemptNumber) pair is the key the state was registered under. | ||
|     cleanupBarrierStage(ContextBarrierId(finished.stageId, finished.attemptNumber)) | ||
|   } | ||
| } | ||
|
|
||
| // Record all active stage attempts that make barrier() call(s), and the corresponding internal | ||
| // state. | ||
| private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] | ||
|
|
||
| // Endpoint start hook: after the superclass start logic has run, register the stage-completion | ||
| // listener on the listener bus so that state of finished stage attempts gets cleaned up. | ||
| override def onStart(): Unit = { | ||
| super.onStart() | ||
| listenerBus.addToStatusQueue(listener) | ||
| } | ||
|
|
||
| // Endpoint stop hook: clear every registered barrier state, unregister the stage listener and | ||
| // shut down the timeout timer before delegating to the superclass. | ||
| override def onStop(): Unit = { | ||
|   try { | ||
|     states.forEachValue(1, clearStateConsumer) | ||
|     states.clear() | ||
|     listenerBus.removeListener(listener) | ||
|   } finally { | ||
|     // Terminate the timer thread, otherwise it outlives the coordinator. Done in `finally` so | ||
|     // the timer is stopped even if the cleanup above throws. `timer` is lazy, so this may | ||
|     // create it only to cancel it right away, which is harmless. | ||
|     timer.cancel() | ||
|     super.onStop() | ||
|   } | ||
| } | ||
|
|
||
| /** | ||
| * Provide the current state of a barrier() call. A state is created when a new stage attempt | ||
| * sends out a barrier() call, and recycled on stage completed. | ||
| * | ||
| * @param barrierId Identifier of the barrier stage that makes a barrier() call. | ||
| * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall | ||
| * collect `numTasks` requests to succeed. | ||
| */ | ||
| private class ContextBarrierState( | ||
| val barrierId: ContextBarrierId, | ||
| val numTasks: Int) { | ||
|
|
||
| // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used | ||
| // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or | ||
| // reset when a barrier() call fails due to timeout. | ||
| private var barrierEpoch: Int = 0 | ||
|
|
||
| // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() | ||
| // call. | ||
| private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) | ||
|
|
||
| // A timer task that ensures we may timeout for a barrier() call. | ||
| private var timerTask: TimerTask = null | ||
|
|
||
| // Init a TimerTask for a barrier() call. When it fires, it fails every pending sync request | ||
| // with a timeout error and removes the state of this stage attempt. | ||
| private def initTimerTask(): Unit = { | ||
|   timerTask = new TimerTask { | ||
|     // Lock the enclosing ContextBarrierState explicitly: a bare `synchronized` here would lock | ||
|     // this anonymous TimerTask instead, which does not exclude handleRequest() and clear(). | ||
|     override def run(): Unit = ContextBarrierState.this.synchronized { | ||
|       // Do nothing if this task was cancelled or replaced while it waited for the lock, e.g. | ||
|       // when the last sync request arrived just as the timeout fired. TimerTask.cancel() does | ||
|       // not stop a run that has already started. | ||
|       if (this eq timerTask) { | ||
|         // Timeout current barrier() call, fail all the sync requests. | ||
|         requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + | ||
|           s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + | ||
|           s"$timeoutInSecs second(s)."))) | ||
|         cleanupBarrierStage(barrierId) | ||
|       } | ||
|     } | ||
|   } | ||
| } | ||
|
|
||
| // Cancel the current active TimerTask and release resources. No-op when no task is active. | ||
| private def cancelTimerTask(): Unit = { | ||
|   if (timerTask != null) { | ||
|     timerTask.cancel() | ||
|     // A cancelled task stays in the timer's queue, and keeps this state reachable, until its | ||
|     // scheduled time is reached. Purge it now so long timeouts do not accumulate dead tasks. | ||
|     timer.purge() | ||
|     timerTask = null | ||
|   } | ||
| } | ||
|
|
||
| // Process the global sync request. The barrier() call succeeds if enough requests are collected | ||
| // within the configured time, otherwise all the pending requests fail. | ||
| def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { | ||
| val taskId = request.taskAttemptId | ||
| val epoch = request.barrierEpoch | ||
|
|
||
| // Require the number of tasks is correctly set from the BarrierTaskContext. | ||
| require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + | ||
| s"${request.numTasks} from Task $taskId, previously it was $numTasks.") | ||
|
|
||
| // Check whether the epoch from the barrier tasks matches current barrierEpoch. | ||
| logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") | ||
| if (epoch != barrierEpoch) { | ||
| requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + | ||
| s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + | ||
| "properly killed.")) | ||
| } else { | ||
| // If this is the first sync message received for a barrier() call, start timer to ensure | ||
| // we may timeout for the sync. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We create
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is true if you have only one
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah i see, makes sense |
||
| if (requesters.isEmpty) { | ||
| initTimerTask() | ||
| timer.schedule(timerTask, timeoutInSecs * 1000) | ||
| } | ||
| // Add the requester to array of RPCCallContexts pending for reply. | ||
| requesters += requester | ||
| logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + | ||
| s"$taskId, current progress: ${requesters.size}/$numTasks.") | ||
| if (maybeFinishAllRequesters(requesters, numTasks)) { | ||
| // Finished current barrier() call successfully, clean up ContextBarrierState and | ||
| // increase the barrier epoch. | ||
| logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + | ||
| s"tasks, finished successfully.") | ||
| barrierEpoch += 1 | ||
| requesters.clear() | ||
| cancelTimerTask() | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // Reply to every pending barrier sync request of a stage attempt once all `numTasks` requests | ||
| // have arrived. Returns true if the replies were sent, false if requests are still missing. | ||
| private def maybeFinishAllRequesters( | ||
|     requesters: ArrayBuffer[RpcCallContext], | ||
|     numTasks: Int): Boolean = { | ||
|   val allArrived = requesters.size == numTasks | ||
|   if (allArrived) { | ||
|     requesters.foreach(_.reply(())) | ||
|   } | ||
|   allArrived | ||
| } | ||
|
|
||
| // Cleanup the internal state of a barrier stage attempt. | ||
| def clear(): Unit = synchronized { | ||
| // The global sync fails so the stage is expected to retry another attempt, all sync | ||
| // messages come from current stage attempt shall fail. | ||
| barrierEpoch = -1 | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here, you expect we will issue the following exception after we change it to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea |
||
| requesters.clear() | ||
| cancelTimerTask() | ||
| } | ||
| } | ||
|
|
||
| // Remove the [[ContextBarrierState]] registered for the given stage attempt, if any, and clear it. | ||
| private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { | ||
|   // `remove` yields null when nothing is registered under `barrierId`; Option(...) absorbs that. | ||
|   Option(states.remove(barrierId)).foreach(_.clear()) | ||
| } | ||
|
|
||
| // Route each global sync request to the state of the stage attempt it belongs to, creating | ||
| // that state on the first request of the attempt. | ||
| override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { | ||
|   case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => | ||
|     // Get or init the ContextBarrierState corresponding to the stage attempt. Use the value | ||
|     // returned by computeIfAbsent instead of a second lookup: the listener or timer thread may | ||
|     // remove the entry in between, in which case `states.get` would return null. | ||
|     val barrierId = ContextBarrierId(stageId, stageAttemptId) | ||
|     val barrierState = states.computeIfAbsent(barrierId, | ||
|       new Function[ContextBarrierId, ContextBarrierState] { | ||
|         override def apply(key: ContextBarrierId): ContextBarrierState = | ||
|           new ContextBarrierState(key, numTasks) | ||
|       }) | ||
|     barrierState.handleRequest(context, request) | ||
| } | ||
|
|
||
| // Consumer that clears a single ContextBarrierState; passed to `states.forEachValue` in onStop(). | ||
| private val clearStateConsumer = new Consumer[ContextBarrierState] { | ||
|   override def accept(state: ContextBarrierState): Unit = { | ||
|     state.clear() | ||
|   } | ||
| } | ||
| } | ||
|
|
||
| private[spark] sealed trait BarrierCoordinatorMessage extends Serializable | ||
|
|
||
| /** | ||
| * A global sync request message sent by BarrierTaskContext for a `barrier()` call. Each request | ||
| * is identified by stageId + stageAttemptId + barrierEpoch. | ||
| * | ||
| * @param numTasks The number of global sync requests the BarrierCoordinator shall receive | ||
| * @param stageId ID of current stage | ||
| * @param stageAttemptId ID of current stage attempt | ||
| * @param taskAttemptId Unique ID of current task | ||
| * @param barrierEpoch ID of the `barrier()` call, a task may contain multiple `barrier()` calls. | ||
| */ | ||
| private[spark] case class RequestToSync( | ||
| numTasks: Int, | ||
| stageId: Int, | ||
| stageAttemptId: Int, | ||
| taskAttemptId: Long, | ||
| barrierEpoch: Int) extends BarrierCoordinatorMessage | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,12 +17,17 @@ | |
|
|
||
| package org.apache.spark | ||
|
|
||
| import java.util.Properties | ||
| import java.util.{Properties, Timer, TimerTask} | ||
|
|
||
| import scala.concurrent.duration._ | ||
| import scala.language.postfixOps | ||
|
|
||
| import org.apache.spark.annotation.{Experimental, Since} | ||
| import org.apache.spark.executor.TaskMetrics | ||
| import org.apache.spark.memory.TaskMemoryManager | ||
| import org.apache.spark.metrics.MetricsSystem | ||
| import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} | ||
| import org.apache.spark.util.{RpcUtils, Utils} | ||
|
|
||
| /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ | ||
| class BarrierTaskContext( | ||
|
|
@@ -39,6 +44,22 @@ class BarrierTaskContext( | |
| extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, | ||
| taskMemoryManager, localProperties, metricsSystem, taskMetrics) { | ||
|
|
||
| // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. | ||
| private val barrierCoordinator: RpcEndpointRef = { | ||
| val env = SparkEnv.get | ||
| RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) | ||
| } | ||
|
|
||
| private val timer = new Timer("Barrier task timer for barrier() calls.") | ||
|
|
||
| // Local barrierEpoch that identify a barrier() call from current task, it shall be identical | ||
| // with the driver side epoch. | ||
| private var barrierEpoch = 0 | ||
|
|
||
| // Number of tasks of the current barrier stage, a barrier() call must collect enough requests | ||
| // from different tasks within the same barrier stage attempt to succeed. | ||
| private lazy val numTasks = getTaskInfos().size | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If change it to a |
||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to | ||
|
|
@@ -80,7 +101,44 @@ class BarrierTaskContext( | |
| @Experimental | ||
| @Since("2.4.0") | ||
| def barrier(): Unit = { | ||
|   // Common prefix naming this task attempt in every log line of this call. | ||
|   def taskDesc: String = s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber)" | ||
|   logInfo(s"$taskDesc has entered the global sync, current barrier epoch is $barrierEpoch.") | ||
|   logTrace("Current callSite: " + Utils.getCallSite()) | ||
|   val startTime = System.currentTimeMillis() | ||
|   // Whole seconds elapsed since this barrier() call started. | ||
|   def waitedSecs: Long = (System.currentTimeMillis() - startTime) / 1000 | ||
|   // Report that the task is still waiting: first after 60 seconds, then every 60 seconds. | ||
|   val progressLogger = new TimerTask { | ||
|     override def run(): Unit = { | ||
|       logInfo(s"$taskDesc waiting under the global sync since $startTime, has been waiting " + | ||
|         s"for $waitedSecs seconds, current barrier epoch is $barrierEpoch.") | ||
|     } | ||
|   } | ||
|   timer.schedule(progressLogger, 60000, 60000) | ||
|   try { | ||
|     // The RPC timeout is fixed at one year (31536000 = 3600 * 24 * 365 seconds), so that on | ||
|     // timeout users get the SparkException thrown by BarrierCoordinator instead of an | ||
|     // RPCTimeoutException from the RPC framework. | ||
|     barrierCoordinator.askSync[Unit]( | ||
|       message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, barrierEpoch), | ||
|       timeout = new RpcTimeout(31536000.seconds, "barrierTimeout")) | ||
|     // The sync succeeded: advance the local epoch for the next barrier() call. | ||
|     barrierEpoch += 1 | ||
|     logInfo(s"$taskDesc finished global sync successfully, waited for $waitedSecs seconds, " + | ||
|       s"current barrier epoch is $barrierEpoch.") | ||
|   } catch { | ||
|     case e: SparkException => | ||
|       logInfo(s"$taskDesc failed to perform global sync, waited for $waitedSecs seconds, " + | ||
|         s"current barrier epoch is $barrierEpoch.") | ||
|       throw e | ||
|   } finally { | ||
|     // Stop the periodic progress logging whether the sync succeeded or failed. | ||
|     progressLogger.cancel() | ||
|   } | ||
| } | ||
|
|
||
| /** | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1930,6 +1930,12 @@ class SparkContext(config: SparkConf) extends Logging { | |
| Utils.tryLogNonFatalError { | ||
| _executorAllocationManager.foreach(_.stop()) | ||
| } | ||
| if (_dagScheduler != null) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why this change?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to fix #21898 (comment) , previously LiveListenerBus was stopped before we stop DAGScheduler. |
||
| Utils.tryLogNonFatalError { | ||
| _dagScheduler.stop() | ||
| } | ||
| _dagScheduler = null | ||
| } | ||
| if (_listenerBusStarted) { | ||
| Utils.tryLogNonFatalError { | ||
| listenerBus.stop() | ||
|
|
@@ -1939,12 +1945,6 @@ class SparkContext(config: SparkConf) extends Logging { | |
| Utils.tryLogNonFatalError { | ||
| _eventLogger.foreach(_.stop()) | ||
| } | ||
| if (_dagScheduler != null) { | ||
| Utils.tryLogNonFatalError { | ||
| _dagScheduler.stop() | ||
| } | ||
| _dagScheduler = null | ||
| } | ||
| if (env != null && _heartbeatReceiver != null) { | ||
| Utils.tryLogNonFatalError { | ||
| env.rpcEnv.stop(_heartbeatReceiver) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why `synchronized`? I think we only access `ContextBarrierState` in the RPC thread, which is single-threaded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah i see, the timer thread also accesses `ContextBarrierState`, so the code in the timer thread also needs to be synchronized.