[WIP][SPARK-24375][Prototype] Support barrier scheduling #21494

Changes from all commits
@@ -0,0 +1,78 @@ (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.barrier

import java.util.{Timer, TimerTask}

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

class BarrierCoordinator(
    numTasks: Int,
    timeout: Long,
    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  private var epoch = 0
```
> **Contributor:** Will
```scala
  private val timer = new Timer("BarrierCoordinator epoch increment timer")

  private val syncRequests = new scala.collection.mutable.ArrayBuffer[RpcCallContext](numTasks)

  private def replyIfGetAllSyncRequest(): Unit = {
    if (syncRequests.length == numTasks) {
      syncRequests.foreach(_.reply(()))
      syncRequests.clear()
      epoch += 1
    }
  }

  override def receive: PartialFunction[Any, Unit] = {
    case IncreaseEpoch(previousEpoch) =>
      if (previousEpoch == epoch) {
        syncRequests.foreach(_.sendFailure(new RuntimeException(
          s"The coordinator cannot get all barrier sync requests within $timeout ms.")))
```
> **Member:** Have we considered incrementally increasing the timeout when we can't get all barrier sync requests at an epoch?
```scala
        syncRequests.clear()
        epoch += 1
      }
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RequestToSync(epoch) =>
      if (epoch == this.epoch) {
        if (syncRequests.isEmpty) {
          val currentEpoch = epoch
          timer.schedule(new TimerTask {
            override def run(): Unit = {
              // self can be null after this RPC endpoint is stopped.
              if (self != null) self.send(IncreaseEpoch(currentEpoch))
```
> **Member:** Once this epoch fails to sync, the stage will be failed and resubmitted. I think it will begin from a new task set, so

> **Member:** Register a task-level barrier sequence and hierarchy, maybe?
```scala
            }
          }, timeout)
        }

        syncRequests += context
        replyIfGetAllSyncRequest()
      }
```
> **Member:**
>
> ```scala
> if (epoch == this.epoch) {
>   ...
> } else { // Received RpcCallContext from failed previousEpoch.
>   context.sendFailure(new RuntimeException(
>     s"The coordinator cannot get all barrier sync requests within $timeout ms."))
> }
> ```
```scala
  }

  override def onStop(): Unit = timer.cancel()
}

private[barrier] sealed trait BarrierCoordinatorMessage extends Serializable

private[barrier] case class RequestToSync(epoch: Int) extends BarrierCoordinatorMessage

private[barrier] case class IncreaseEpoch(previousEpoch: Int) extends BarrierCoordinatorMessage
```
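For orientation (not part of the diff): the coordinator is meant to run as a named RPC endpoint on the driver, one per barrier stage attempt, under the same name that `BarrierTaskContext` below looks up. A minimal, hypothetical registration sketch, assuming a live `SparkEnv`; the stage identifiers and task count are illustrative values:

```scala
// Hypothetical driver-side setup sketch; not part of this diff.
// The endpoint name must match the lookup in BarrierTaskContext,
// i.e. s"barrier-$stageId-$stageAttemptNumber".
import org.apache.spark.SparkEnv

val stageId = 1            // illustrative values
val stageAttemptNumber = 0
val env = SparkEnv.get
val coordinator = new BarrierCoordinator(
  numTasks = 4,            // tasks that must reach the barrier together
  timeout = 60 * 1000L,    // fail the sync if they don't all arrive in 60s
  rpcEnv = env.rpcEnv)
env.rpcEnv.setupEndpoint(s"barrier-$stageId-$stageAttemptNumber", coordinator)
```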
@@ -0,0 +1,43 @@ (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.barrier

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD

/**
 * An RDD that supports running MPI programs.
 */
class BarrierRDD[T: ClassTag](var prev: RDD[T]) extends RDD[T](prev) {

  override def isBarrier(): Boolean = true

  override def getPartitions: Array[Partition] = prev.partitions

  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
    prev.iterator(split, context)
  }

  override def clearDependencies(): Unit = {
    super.clearDependencies()
    prev = null
  }
}
```
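A hedged usage sketch of the prototype API as it stands: wrapping an RDD in `BarrierRDD` marks its stage as a barrier stage, and tasks would synchronize via `BarrierTaskContext.barrier()` (next file). How the barrier context actually reaches user code is not settled in this prototype, so the cast below is an assumption:

```scala
// Hypothetical usage sketch; not part of this diff.
import org.apache.spark.TaskContext

val rdd = sc.parallelize(1 to 100, numSlices = 4)
val barrierRdd = new BarrierRDD(rdd)
barrierRdd.mapPartitions { iter =>
  // Assumption: the scheduler installs a BarrierTaskContext for barrier tasks.
  val context = TaskContext.get().asInstanceOf[BarrierTaskContext]
  // ... per-task setup, e.g. start a local MPI worker ...
  context.barrier() // block until all 4 tasks reach this point
  iter
}.count()
```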
@@ -0,0 +1,67 @@ (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.barrier

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContextImpl}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.RpcUtils

class BarrierTaskContext(
```
> **Member:** `BarrierTaskContextImpl`?
```scala
    override val stageId: Int,
    override val stageAttemptNumber: Int,
    override val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    localProperties: Properties,
    @transient private val metricsSystem: MetricsSystem,
    // The default value is only used in tests.
    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
  // TODO: make this extend TaskContext
  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
    taskMemoryManager, localProperties, metricsSystem, taskMetrics)
  with Logging {

  private val barrierCoordinator = {
    val env = SparkEnv.get
    RpcUtils.makeDriverRef(s"barrier-$stageId-$stageAttemptNumber", env.conf, env.rpcEnv)
  }

  private var epoch = 0

  /**
   * Returns an array of the hosts that the barrier tasks are running on.
   */
  def hosts(): Array[String] = {
    val hostsStr = localProperties.getProperty("hosts", "")
    hostsStr.trim().split(",").map(_.trim())
  }

  /**
   * Blocks until all the barrier tasks in the same task set have reached this barrier.
   */
  def barrier(): Unit = synchronized {
    barrierCoordinator.askSync[Unit](RequestToSync(epoch))
    epoch += 1
  }
}
```
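To illustrate how `hosts()` and `barrier()` might combine in an MPI-style job, a hypothetical helper follows; note that nothing in this diff populates the `"hosts"` local property yet, so that is an assumption:

```scala
// Hypothetical sketch; not part of this diff. Assumes the scheduler
// fills in the "hosts" local property, which this prototype does not
// do yet.
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

def writeHostfile(context: BarrierTaskContext): Unit = {
  // Every task writes the same host list to a local hostfile.
  Files.write(Paths.get("/tmp/hostfile"),
    context.hosts().mkString("\n").getBytes(StandardCharsets.UTF_8))
  // Make sure every task has written its hostfile before any
  // MPI worker is started against it.
  context.barrier()
}
```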
```diff
@@ -1062,7 +1062,7 @@ class DAGScheduler(
             stage.pendingPartitions += id
             new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
               taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
-              Option(sc.applicationId), sc.applicationAttemptId)
+              Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier)
           }

         case stage: ResultStage =>
@@ -1072,7 +1072,8 @@ class DAGScheduler(
             val locs = taskIdToLocations(id)
             new ResultTask(stage.id, stage.latestInfo.attemptNumber,
               taskBinary, part, locs, id, properties, serializedTaskMetrics,
-              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
+              Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
+              stage.rdd.isBarrier())
           }
       }
     } catch {
@@ -1310,6 +1311,44 @@ class DAGScheduler(
         }
       }

+      case failure: TaskFailedReason if task.isBarrier =>
+        // Always fail the current stage and retry all the tasks when a barrier task fails.
+        val failedStage = stageIdToStage(task.stageId)
+        logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+          "due to a barrier task failure.")
+        val message = "Stage failed because a barrier task finished unsuccessfully. " +
+          s"${failure.toErrorString}"
+        try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
+          taskScheduler.cancelTasks(stageId, interruptThread = false)
+        } catch {
+          case e: UnsupportedOperationException =>
+            logInfo(s"Could not cancel tasks for stage $stageId", e)
```
> **Member:** Under barrier execution, will it be a problem if we cannot cancel tasks?
```diff
+        }
+        markStageAsFinished(failedStage, Some(message))
+
+        failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+        val shouldAbortStage =
+          failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            disallowStageRetryForTest
+
+        if (shouldAbortStage) {
+          val abortMessage = if (disallowStageRetryForTest) {
+            "Barrier stage will not retry stage due to testing config"
+          } else {
+            s"""$failedStage (${failedStage.name})
+               |has failed the maximum allowable number of
+               |times: $maxConsecutiveStageAttempts.
+               |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ")
+          }
+          abortStage(failedStage, abortMessage, None)
+        } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued
+          failedStages += failedStage
+          logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage failure.")
+          messageScheduler.schedule(new Runnable {
+            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+        }
+
       case Resubmitted =>
         logInfo("Resubmitted " + task + ", so marking it as still running")
         stage match {
```
> So this would be language dependent? Would need something for the R runner too?