diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 41b5cab601c3..39c57fa56f2c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -34,6 +34,8 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) // Common RDD functions + def barrier(): JavaRDD[T] = wrapRDD(rdd.barrier()) + /** * Persist this RDD with the default storage level (`MEMORY_ONLY`). */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 41eac10d9b26..819b4474aed5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,7 +24,10 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import py4j.GatewayServer + import org.apache.spark._ +import org.apache.spark.barrier.BarrierTaskContext import org.apache.spark.internal.Logging import org.apache.spark.util._ @@ -179,6 +182,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( // Python version of driver PythonRDD.writeUTF(pythonVer, dataOut) // Write out the TaskContextInfo + val isBarrier = context.isInstanceOf[BarrierTaskContext] + dataOut.writeBoolean(isBarrier) + if (isBarrier) { + val port = GatewayServer.DEFAULT_PORT + 2 + context.partitionId() * 2 + val gatewayServer = new GatewayServer( + context.asInstanceOf[BarrierTaskContext], + port, + port + 1, + GatewayServer.DEFAULT_CONNECT_TIMEOUT, + GatewayServer.DEFAULT_READ_TIMEOUT, + null) + // TODO: When to stop it? + gatewayServer.start() + context.addTaskCompletionListener(_ => gatewayServer.shutdown()) + } dataOut.writeInt(context.stageId()) dataOut.writeInt(context.partitionId()) dataOut.writeInt(context.attemptNumber()) diff --git a/core/src/main/scala/org/apache/spark/barrier/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/barrier/BarrierCoordinator.scala new file mode 100644 index 000000000000..573fffeccb82 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/barrier/BarrierCoordinator.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.barrier
+
+import java.util.{Timer, TimerTask}
+
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+
+class BarrierCoordinator(
+    numTasks: Int,
+    timeout: Long,
+    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
+
+  private var epoch = 0
+
+  private val timer = new Timer("BarrierCoordinator epoch increment timer")
+
+  private val syncRequests = new scala.collection.mutable.ArrayBuffer[RpcCallContext](numTasks)
+
+  private def replyIfGetAllSyncRequest(): Unit = {
+    if (syncRequests.length == numTasks) {
+      syncRequests.foreach(_.reply(()))
+      syncRequests.clear()
+      epoch += 1
+    }
+  }
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case IncreaseEpoch(previousEpoch) =>
+      if (previousEpoch == epoch) {
+        syncRequests.foreach(_.sendFailure(new RuntimeException(
+          s"The coordinator cannot get all barrier sync requests within $timeout ms.")))
+        syncRequests.clear()
+        epoch += 1
+      }
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case RequestToSync(epoch) =>
+      if (epoch == this.epoch) {
+        if (syncRequests.isEmpty) {
+          val currentEpoch = epoch
+          timer.schedule(new TimerTask {
+            override def run(): Unit = {
+              // self can be null after this RPC endpoint is stopped.
+              if (self != null) self.send(IncreaseEpoch(currentEpoch))
+            }
+          }, timeout)
+        }
+
+        syncRequests += context
+        replyIfGetAllSyncRequest()
+      }
+  }
+
+  override def onStop(): Unit = timer.cancel()
+}
+
+private[barrier] sealed trait BarrierCoordinatorMessage extends Serializable
+
+private[barrier] case class RequestToSync(epoch: Int) extends BarrierCoordinatorMessage
+
+private[barrier] case class IncreaseEpoch(previousEpoch: Int) extends BarrierCoordinatorMessage
diff --git a/core/src/main/scala/org/apache/spark/barrier/BarrierRDD.scala b/core/src/main/scala/org/apache/spark/barrier/BarrierRDD.scala
new file mode 100644
index 000000000000..60103ceac5b1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/barrier/BarrierRDD.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.barrier
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * An RDD that marks its stage as a barrier stage, e.g. to support running MPI programs.
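// Illustration only, not part of this patch: across executors, the BarrierCoordinator above is
// meant to give barrier tasks roughly the semantics of java.util.concurrent.CyclicBarrier --
// release every waiter once all numTasks sync requests have arrived, and fail all of them if the
// stragglers do not arrive within the timeout. A small in-process sketch of those semantics:
import java.util.concurrent.{CyclicBarrier, TimeUnit}

object BarrierSemanticsSketch {
  def main(args: Array[String]): Unit = {
    val numTasks = 2
    val barrier = new CyclicBarrier(numTasks)
    val workers = (0 until numTasks).map { i =>
      new Thread(new Runnable {
        override def run(): Unit = {
          println(s"task $i reached the barrier")
          barrier.await(60, TimeUnit.SECONDS) // analogous to a barrier task syncing with a timeout
          println(s"task $i released")
        }
      })
    }
    workers.foreach(_.start())
    workers.foreach(_.join())
  }
}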
+ */ +class BarrierRDD[T: ClassTag](var prev: RDD[T]) extends RDD[T](prev) { + + override def isBarrier(): Boolean = true + + override def getPartitions: Array[Partition] = prev.partitions + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + prev.iterator(split, context) + } + + override def clearDependencies() { + super.clearDependencies() + prev = null + } +} diff --git a/core/src/main/scala/org/apache/spark/barrier/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/barrier/BarrierTaskContext.scala new file mode 100644 index 000000000000..bf381510505b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/barrier/BarrierTaskContext.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.barrier + +import java.util.Properties + +import org.apache.spark.{SparkEnv, TaskContextImpl} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.util.RpcUtils + +class BarrierTaskContext( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + // TODO make this extends TaskContext + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) + with Logging { + + private val barrierCoordinator = { + val env = SparkEnv.get + RpcUtils.makeDriverRef(s"barrier-$stageId-$stageAttemptNumber", env.conf, env.rpcEnv) + } + + private var epoch = 0 + + /** + * Returns an Array of executor IDs that the barrier tasks are running on. + */ + def hosts(): Array[String] = { + val hostsStr = localProperties.getProperty("hosts", "") + hostsStr.trim().split(",").map(_.trim()) + } + + /** + * Wait to sync all the barrier tasks in the same taskSet. 
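// Usage sketch of the two methods above, mirroring the "support barrier sync under local mode"
// test added later in this patch. Assumes an active SparkContext `sc` with at least as many slots
// as partitions; hosts() returns the hosts all tasks of this barrier stage run on (ordered by
// partition), and barrier() blocks until every task of the stage has reached the same call.
import org.apache.spark.TaskContext
import org.apache.spark.barrier.BarrierTaskContext

val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2).barrier()
val synced = rdd.mapPartitions { iter =>
  val ctx = TaskContext.get.asInstanceOf[BarrierTaskContext]
  val peers = ctx.hosts()   // one entry per barrier task in this stage
  ctx.barrier()             // rendezvous: no task proceeds until both partitions get here
  iter.map(v => (peers.length, v))
}
synced.collect()            // Array((2,1), (2,2), (2,3), (2,4))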
+ */ + def barrier(): Unit = synchronized { + barrierCoordinator.askSync[Unit](RequestToSync(epoch)) + epoch += 1 + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index d5145094ec07..c6f10a0fdbdf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -226,8 +226,9 @@ private[spark] class ClientApp extends SparkApplication { override def start(args: Array[String], conf: SparkConf): Unit = { val driverArgs = new ClientArguments(args) + // TODO remove this hack if (!conf.contains("spark.rpc.askTimeout")) { - conf.set("spark.rpc.askTimeout", "10s") + conf.set("spark.rpc.askTimeout", "900s") } Logger.getRootLogger.setLevel(driverArgs.logLevel) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0574abdca32a..16ae713de4a7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,6 +35,7 @@ import org.apache.spark._ import org.apache.spark.Partitioner._ import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.barrier.BarrierRDD import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -43,8 +44,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} -import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, - SamplingUtils} +import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -1317,6 +1317,8 @@ abstract class RDD[T: ClassTag]( } } + def barrier(): BarrierRDD[T] = withScope(new BarrierRDD[T](this)) + /** * Take the first num elements of the RDD. It works by first scanning one partition, and use the * results from that partition to estimate the number of additional partitions needed to satisfy @@ -1839,6 +1841,11 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + + /** + * Whether the RDD is a BarrierRDD. 
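// How the flag defined just below propagates (a sketch, assuming an active SparkContext `sc`):
// the default implementation follows the dependency chain, BarrierRDD (earlier in this patch)
// forces it to true, and the ShuffledRDD override later in this patch forces it back to false,
// so the barrier property stops at shuffle boundaries.
val b = sc.makeRDD(1 to 4, 2).barrier()
b.isBarrier()                   // true  -- BarrierRDD overrides it
b.map(_ + 1).isBarrier()        // true  -- inherited through the narrow dependency
b.repartition(2).isBarrier()    // false -- a shuffle breaks the chain (ShuffledRDD override)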
+ */ + def isBarrier(): Boolean = dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 26eaa9aa3d03..489fb8cb8031 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( super.clearDependencies() prev = null } + + override def isBarrier(): Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 041eade82d3c..9adca7d03bed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1062,7 +1062,7 @@ class DAGScheduler( stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), - Option(sc.applicationId), sc.applicationAttemptId) + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier) } case stage: ResultStage => @@ -1072,7 +1072,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, - Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) } } } catch { @@ -1310,6 +1311,44 @@ class DAGScheduler( } } + case failure: TaskFailedReason if task.isBarrier => + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + "due to a barrier task failed.") + val message = "Stage failed because a barrier task finished unsuccessfully. " + + s"${failure.toErrorString}" + try { // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, interruptThread = false) + } catch { + case e: UnsupportedOperationException => + logInfo(s"Could not cancel tasks for stage $stageId", e) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + val shouldAbortStage = + failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. 
+ |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued + failedStages += failedStage + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + case Resubmitted => logInfo("Resubmitted " + task + ", so marking it as still running") stage match { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e36c759a4255..b1979e2b9a62 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -48,7 +48,8 @@ import org.apache.spark.rdd.RDD * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to - */ + * @param isBarrier whether this task belongs to a barrier sync stage. + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -60,9 +61,10 @@ private[spark] class ResultTask[T, U]( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, - jobId, appId, appAttemptId) + jobId, appId, appAttemptId, isBarrier) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 7a25c47e2cab..aafd204fbbbb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -49,6 +49,7 @@ import org.apache.spark.shuffle.ShuffleWriter * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier sync stage. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -60,9 +61,10 @@ private[spark] class ShuffleMapTask( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, - serializedTaskMetrics, jobId, appId, appAttemptId) + serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index f536fc2a5f0a..4024af95fdc4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.util.Properties import org.apache.spark._ +import org.apache.spark.barrier.BarrierTaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} @@ -49,6 +50,7 @@ import org.apache.spark.util._ * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier sync stage. */ private[spark] abstract class Task[T]( val stageId: Int, @@ -60,7 +62,8 @@ private[spark] abstract class Task[T]( SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None, + val isBarrier: Boolean = false) extends Serializable { @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -77,16 +80,30 @@ private[spark] abstract class Task[T]( attemptNumber: Int, metricsSystem: MetricsSystem): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) - context = new TaskContextImpl( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + context = if (isBarrier) { + new BarrierTaskContext( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } else { + new TaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } + TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e40..bb4a4442b943 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -50,6 +50,7 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet + val partitionId: Int, val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, @@ -76,6 +77,7 @@ private[spark] object TaskDescription { dataOut.writeUTF(taskDescription.executorId) dataOut.writeUTF(taskDescription.name) dataOut.writeInt(taskDescription.index) + dataOut.writeInt(taskDescription.partitionId) // Write files. 
serializeStringLongMap(taskDescription.addedFiles, dataOut) @@ -117,6 +119,7 @@ private[spark] object TaskDescription { val executorId = dataIn.readUTF() val name = dataIn.readUTF() val index = dataIn.readInt() + val partitionId = dataIn.readInt() // Read files. val taskFiles = deserializeStringLongMap(dataIn) @@ -138,7 +141,7 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, + taskJars, properties, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 598b62f85a1f..b434d5f6909d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -274,7 +274,9 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]], + hosts: ArrayBuffer[String], + taskDescs: ArrayBuffer[TaskDescription]) : Boolean = { var launchedTask = false // nodes and executors that are blacklisted for the entire application have already been // filtered out by this point @@ -291,6 +293,12 @@ private[spark] class TaskSchedulerImpl( executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + // Only update hosts for a barrier task. + if (taskSet.isBarrier) { + // The executor address is expected to be non empty. + hosts += shuffledOffers(i).host + taskDescs += task + } launchedTask = true } } catch { @@ -345,6 +353,7 @@ private[spark] class TaskSchedulerImpl( val shuffledOffers = shuffleOffers(filteredOffers) // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) + val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val availableCpus = shuffledOffers.map(o => o.cores).toArray val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { @@ -359,17 +368,42 @@ private[spark] class TaskSchedulerImpl( // of locality levels so that it gets a chance to launch local tasks on all of them. // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY for (taskSet <- sortedTaskSets) { - var launchedAnyTask = false - var launchedTaskAtCurrentMaxLocality = false - for (currentMaxLocality <- taskSet.myLocalityLevels) { - do { - launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) - launchedAnyTask |= launchedTaskAtCurrentMaxLocality - } while (launchedTaskAtCurrentMaxLocality) - } - if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + // Skip the barrier taskSet if the available slots are less than the number of pending tasks. + if (taskSet.isBarrier && availableSlots < taskSet.numTasks) { + // Skip the launch process. + } else { + var launchedAnyTask = false + var launchedTaskAtCurrentMaxLocality = false + // Record all the executor IDs assigned barrier tasks on. 
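// Worked example of the availableSlots check introduced above (assuming "spark.task.cpus" = 1,
// so one slot per core, and a hypothetical local[2] master offering 2 slots): barrier tasks are
// only launched when the whole task set fits at once.
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("barrier-slots"))

// 2 barrier tasks, 2 slots: the whole set fits, so it is launched together in one offer round.
sc.makeRDD(1 to 6, 2).barrier().mapPartitions(it => it).collect()

// 3 barrier tasks, 2 slots: with this change the task set is skipped on every offer round, so the
// job would wait indefinitely rather than launch a partial barrier stage -- do not run this line.
// sc.makeRDD(1 to 6, 3).barrier().mapPartitions(it => it).collect()

sc.stop()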
+ val hosts = ArrayBuffer[String]() + val taskDescs = ArrayBuffer[TaskDescription]() + for (currentMaxLocality <- taskSet.myLocalityLevels) { + do { + launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, + currentMaxLocality, shuffledOffers, availableCpus, tasks, hosts, taskDescs) + launchedAnyTask |= launchedTaskAtCurrentMaxLocality + } while (launchedTaskAtCurrentMaxLocality) + } + if (!launchedAnyTask) { + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + } + if (launchedAnyTask && taskSet.isBarrier) { + // Check whether the barrier tasks are partially launched. + // TODO handle the assert failure case (that can happen when some locality requirements + // are not fulfilled, and we should revert the launched tasks) + assert (taskDescs.size == taskSet.numTasks) + + // materialize the barrier coordinator. + taskSet.barrierCoordinator + + // Update the taskInfos into all the barrier task properties. + val hostsStr = hosts.zip(taskDescs) + // Addresses ordered by partitionId + .sortBy(_._2.partitionId) + .map(_._1) + .mkString(",") + taskDescs.foreach(_.properties.setProperty("hosts", hostsStr)) + } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a18c66596852..577cd0f5ae94 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -27,6 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.barrier.BarrierCoordinator import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} @@ -123,6 +124,21 @@ private[spark] class TaskSetManager( // TODO: We should kill any running task attempts when the task set manager becomes a zombie. private[scheduler] var isZombie = false + private[scheduler] lazy val barrierCoordinator = { + if (isBarrier) { + val timeout = conf.getInt("spark.barrier.sync.timeout", 900000) + val coordinator = new BarrierCoordinator(tasks.length, timeout, env.rpcEnv) + env.rpcEnv.setupEndpoint(s"barrier-${taskSet.stageId}-${taskSet.stageAttemptId}", coordinator) + logInfo("Registered BarrierCoordinator endpoint") + Some(coordinator) + } else { + None + } + } + + // Whether the taskSet is from a barrier stage. + private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier + // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect @@ -512,6 +528,7 @@ private[spark] class TaskSetManager( execId, taskName, index, + task.partitionId, addedFiles, addedJars, task.localProperties, @@ -525,6 +542,7 @@ private[spark] class TaskSetManager( private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) + barrierCoordinator.foreach(_.stop()) if (tasksSuccessful == numTasks) { blacklistTracker.foreach(_.updateBlacklistForSuccessfulTaskSet( taskSet.stageId, @@ -976,8 +994,8 @@ private[spark] class TaskSetManager( */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie. 
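// The BarrierCoordinator registered above reads its timeout from "spark.barrier.sync.timeout"
// (milliseconds, defaulting to 900000, i.e. 15 minutes). A configuration sketch -- the value
// below is just an example:
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("barrier-app")
  .set("spark.barrier.sync.timeout", "60000") // fail all pending barrier sync requests after 60s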
- if (isZombie || numTasks == 1) { + // zombie or is from a barrier stage. + if (isZombie || isBarrier || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 810b36cddf83..c65feb72a1d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,8 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. */ private[spark] -case class WorkerOffer(executorId: String, host: String, cores: Int) +case class WorkerOffer( + executorId: String, + host: String, + cores: Int, + address: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index d8794e8e551a..2752f1182ca0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -243,7 +243,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + new WorkerOffer(id, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort)) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -268,7 +269,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort))) scheduler.resourceOffers(workOffers) } else { Seq.empty diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 4c614c5c0f60..cf8b0ff4f701 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,8 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, + Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task) diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index e5cccf39f945..bd80e9d75555 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -44,7 +44,7 @@ private[spark] object RpcUtils { /** Returns the default Spark timeout to use for RPC ask operations. 
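// End-to-end sketch of the "hosts" task property wired through the scheduler changes above:
// TaskSchedulerImpl collects the WorkerOffer host of every launched barrier task, orders the
// values by partitionId, joins them with commas into a task property, and
// BarrierTaskContext.hosts() splits that property back into an array. Hostnames are made up.
val assembled = Seq((1, "host-b"), (0, "host-a"))  // (partitionId, host) pairs, one per task
  .sortBy(_._1)
  .map(_._2)
  .mkString(",")
// assembled == "host-a,host-b"; the scheduler stores it via properties.setProperty("hosts", ...)

val parsed = assembled.trim().split(",").map(_.trim())  // what hosts() does inside the task
// parsed sameElements Array("host-a", "host-b")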
*/ def askRpcTimeout(conf: SparkConf): RpcTimeout = { - RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") + RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "900s") } /** Returns the default Spark timeout to use for RPC remote endpoint lookup. */ diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ce9f2be1c02d..0dedb5423163 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -627,6 +627,52 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } + + test("support barrier sync under local mode") { + val conf = new SparkConf().setAppName("test").setMaster("local[2]") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2).barrier() + val rdd2 = rdd.mapPartitions { it => + val tc = TaskContext.get.asInstanceOf[org.apache.spark.barrier.BarrierTaskContext] + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (tc.hosts().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${tc.hosts().length}.") + } + // println(tc.getTaskInfos().toList) + tc.barrier() + it + } + rdd2.collect + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + test("support barrier sync under local-cluster mode") { + val conf = new SparkConf() + .setMaster("local-cluster[3, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2).barrier() + val rdd2 = rdd.mapPartitions { it => + val tc = TaskContext.get.asInstanceOf[org.apache.spark.barrier.BarrierTaskContext] + // If we don't get the expected taskInfos, the job shall abort due to stage failure. 
+ if (tc.hosts().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${tc.hosts().length}.") + } + // println(tc.getTaskInfos().toList) + tc.barrier() + it + } + rdd2.collect + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1a7bebe2c53c..77a7668d3a1d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executorId = "", name = "", index = 0, + partitionId = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca..ba62eec0522d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite { executorId = "testExecutor", name = "task for test", index = 19, + partitionId = 1, originalFiles, originalJars, originalProperties, @@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) assert(decodedTaskDescription.name === originalTaskDescription.name) assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId) assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128502ab..64248a6b3a88 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -206,6 +206,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False + self._is_barrier = False self.ctx = ctx self._jrdd_deserializer = jrdd_deserializer self._id = jrdd.id() @@ -331,6 +332,10 @@ def getCheckpointFile(self): if checkpointFile.isDefined(): return checkpointFile.get() + def barrier(self): + self._is_barrier = True + return self + def map(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each element of this RDD. 
@@ -2461,6 +2466,7 @@ def pipeline_func(split, iterator): prev.preservesPartitioning and preservesPartitioning self._prev_jrdd = prev._prev_jrdd # maintain the pipeline self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer + self._is_barrier = prev._is_barrier self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -2469,7 +2475,8 @@ def pipeline_func(split, iterator): self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False - self.partitioner = prev.partitioner if self.preservesPartitioning else None + self.partitioner = \ + prev.partitioner if self.preservesPartitioning else None def getNumPartitions(self): return self._prev_jrdd.partitions().size() @@ -2490,7 +2497,10 @@ def _jrdd(self): self._jrdd_deserializer, profiler) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, self.preservesPartitioning) - self._jrdd_val = python_rdd.asJavaRDD() + if (self._is_barrier): + self._jrdd_val = python_rdd.asJavaRDD().barrier() + else: + self._jrdd_val = python_rdd.asJavaRDD() if profiler: self._id = self._jrdd_val.id() diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 15753f77bd90..f6ee7ee1084f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -706,6 +706,13 @@ def write_int(value, stream): stream.write(struct.pack("!i", value)) +def read_bool(stream): + length = stream.read(1) + if not length: + raise EOFError + return struct.unpack("!?", length)[0] + + def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 63ae1f30e17c..656d677d93f3 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -29,6 +29,7 @@ class TaskContext(object): """ _taskContext = None + _javaContext = None _attemptNumber = None _partitionId = None @@ -95,3 +96,16 @@ def getLocalProperty(self, key): Get a local property set upstream in the driver, or None if it is missing. 
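// The read_bool helper above consumes exactly one byte and decodes it with struct.unpack("!?"),
// which matches what the JVM side writes via DataOutputStream.writeBoolean in the PythonRunner
// change earlier in this patch. A self-contained round-trip sketch of that framing:
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(buffer)
out.writeBoolean(true)   // the isBarrier flag: a single byte, 0x01 for true, 0x00 for false
out.flush()

val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
val isBarrier = in.readBoolean()   // what read_bool recovers on the Python worker side
assert(isBarrier)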
""" return self._localProperties.get(key, None) + + def barrier(self): + if self._javaContext is None: + raise Exception("not barrier") + else: + self._javaContext.barrier() + + def hosts(self): + if self._javaContext is None: + raise Exception("not barrier") + else: + java_list = self._javaContext.hosts() + return [h for h in java_list] diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbcb8af8bfb2..79640275c7da 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -25,13 +25,15 @@ import socket import traceback +from py4j.java_gateway import JavaGateway, GatewayParameters + from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.rdd import PythonEvalType -from pyspark.serializers import write_with_length, write_int, read_long, \ +from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type @@ -218,6 +220,8 @@ def main(infile, outfile): # initialize global state taskContext = TaskContext._getOrCreate() + isBarrier = read_bool(infile) + taskContext._stageId = read_int(infile) taskContext._partitionId = read_int(infile) taskContext._attemptNumber = read_int(infile) @@ -232,6 +236,13 @@ def main(infile, outfile): shuffle.DiskBytesSpilled = 0 _accumulatorRegistry.clear() + if (isBarrier): + port = 25333 + 2 + 2 * taskContext._partitionId + paras = GatewayParameters(port=port) + taskContext._javaContext = \ + JavaGateway(python_proxy_port=port+1, + gateway_parameters=paras).entry_point + # fetch name of workdir spark_files_dir = utf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 2d2f90c63a30..31f84310485a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), @@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(),