diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala new file mode 100644 index 000000000000..4c358629dee9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + +/** A [[TaskContext]] with extra info and tooling for a barrier stage. */ +trait BarrierTaskContext extends TaskContext { + + /** + * :: Experimental :: + * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to + * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same + * stage have reached this routine. + */ + @Experimental + @Since("2.4.0") + def barrier(): Unit + + /** + * :: Experimental :: + * Returns the all task infos in this barrier stage, the task infos are ordered by partitionId. + */ + @Experimental + @Since("2.4.0") + def getTaskInfos(): Array[BarrierTaskInfo] +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala new file mode 100644 index 000000000000..8ac705757a38 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.Properties + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem + +/** A [[BarrierTaskContext]] implementation. */ +private[spark] class BarrierTaskContextImpl( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) + with BarrierTaskContext { + + // TODO SPARK-24817 implement global barrier. + override def barrier(): Unit = {} + + override def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } +} diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala new file mode 100644 index 000000000000..ce2653df2e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierTaskInfo.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.{Experimental, Since} + + +/** + * :: Experimental :: + * Carries all task infos of a barrier task. + * + * @param address the IPv4 address(host:port) of the executor that a barrier task is running on + */ +@Experimental +@Since("2.4.0") +class BarrierTaskInfo(val address: String) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 73646051f264..1c4fa4bc6541 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -434,6 +434,18 @@ private[spark] class MapOutputTrackerMaster( } } + /** Unregister all map output information of the given shuffle. */ + def unregisterAllMapOutput(shuffleId: Int) { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeOutputsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int) { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index e4587c96eae1..904d9c025629 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -23,11 +23,21 @@ import org.apache.spark.{Partition, TaskContext} /** * An RDD that applies the provided function to every partition of the parent RDD. + * + * @param prev the parent RDD. + * @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to + * an output iterator. + * @param preservesPartitioning Whether the input function preserves the partitioner, which should + * be `false` unless `prev` is a pair RDD and the input function + * doesn't modify the keys. + * @param isFromBarrier Indicates whether this RDD is transformed from an RDDBarrier, a stage + * containing at least one RDDBarrier shall be turned into a barrier stage. */ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) - preservesPartitioning: Boolean = false) + preservesPartitioning: Boolean = false, + isFromBarrier: Boolean = false) extends RDD[U](prev) { override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None @@ -41,4 +51,7 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( super.clearDependencies() prev = null } + + @transient protected lazy override val isBarrier_ : Boolean = + isFromBarrier || dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0574abdca32a..cbc1143126d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -33,7 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble @@ -1647,6 +1647,14 @@ abstract class RDD[T: ClassTag]( } } + /** + * :: Experimental :: + * Indicates that Spark must launch the tasks together for the current stage. + */ + @Experimental + @Since("2.4.0") + def barrier(): RDDBarrier[T] = withScope(new RDDBarrier[T](this)) + // ======================================================================= // Other internal methods and fields // ======================================================================= @@ -1839,6 +1847,23 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } + + /** + * Whether the RDD is in a barrier stage. Spark must launch all the tasks at the same time for a + * barrier stage. + * + * An RDD is in a barrier stage, if at least one of its parent RDD(s), or itself, are mapped from + * an [[RDDBarrier]]. This function always returns false for a [[ShuffledRDD]], since a + * [[ShuffledRDD]] indicates start of a new stage. + * + * A [[MapPartitionsRDD]] can be transformed from an [[RDDBarrier]], under that case the + * [[MapPartitionsRDD]] shall be marked as barrier. + */ + private[spark] def isBarrier(): Boolean = isBarrier_ + + // From performance concern, cache the value to avoid repeatedly compute `isBarrier()` on a long + // RDD chain. + @transient protected lazy val isBarrier_ : Boolean = dependencies.exists(_.rdd.isBarrier()) } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala new file mode 100644 index 000000000000..85565d16e271 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.BarrierTaskContext +import org.apache.spark.TaskContext +import org.apache.spark.annotation.{Experimental, Since} + +/** Represents an RDD barrier, which forces Spark to launch tasks of this stage together. */ +class RDDBarrier[T: ClassTag](rdd: RDD[T]) { + + /** + * :: Experimental :: + * Maps partitions together with a provided BarrierTaskContext. + * + * `preservesPartitioning` indicates whether the input function preserves the partitioner, which + * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. + */ + @Experimental + @Since("2.4.0") + def mapPartitions[S: ClassTag]( + f: (Iterator[T], BarrierTaskContext) => Iterator[S], + preservesPartitioning: Boolean = false): RDD[S] = rdd.withScope { + val cleanedF = rdd.sparkContext.clean(f) + new MapPartitionsRDD( + rdd, + (context: TaskContext, index: Int, iter: Iterator[T]) => + cleanedF(iter, context.asInstanceOf[BarrierTaskContext]), + preservesPartitioning, + isFromBarrier = true + ) + } + + /** TODO extra conf(e.g. timeout) */ +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 26eaa9aa3d03..e8f9b27b7eb5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag]( super.clearDependencies() prev = null } + + private[spark] override def isBarrier(): Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala index 949e88f60627..6e4d062749d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ActiveJob.scala @@ -60,4 +60,10 @@ private[spark] class ActiveJob( val finished = Array.fill[Boolean](numPartitions)(false) var numFinished = 0 + + /** Resets the status of all partitions in this stage so they are marked as not finished. */ + def resetAllPartitions(): Unit = { + (0 until numPartitions).foreach(finished.update(_, false)) + numFinished = 0 + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f74425d73b39..003d64f78e85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1062,7 +1062,7 @@ class DAGScheduler( stage.pendingPartitions += id new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId), - Option(sc.applicationId), sc.applicationAttemptId) + Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier()) } case stage: ResultStage => @@ -1072,7 +1072,8 @@ class DAGScheduler( val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptNumber, taskBinary, part, locs, id, properties, serializedTaskMetrics, - Option(jobId), Option(sc.applicationId), sc.applicationAttemptId) + Option(jobId), Option(sc.applicationId), sc.applicationAttemptId, + stage.rdd.isBarrier()) } } } catch { @@ -1311,17 +1312,6 @@ class DAGScheduler( } } - case Resubmitted => - logInfo("Resubmitted " + task + ", so marking it as still running") - stage match { - case sms: ShuffleMapStage => - sms.pendingPartitions += task.partitionId - - case _ => - assert(false, "TaskSetManagers should only send Resubmitted task statuses for " + - "tasks in ShuffleMapStages.") - } - case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1331,9 +1321,9 @@ class DAGScheduler( s" ${task.stageAttemptId} and there is a more recent attempt for that stage " + s"(attempt ${failedStage.latestInfo.attemptNumber}) running") } else { - failedStage.fetchFailedAttemptIds.add(task.stageAttemptId) + failedStage.failedAttemptIds.add(task.stageAttemptId) val shouldAbortStage = - failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts || + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || disallowStageRetryForTest // It is likely that we receive multiple FetchFailed for a single stage (because we have @@ -1349,6 +1339,29 @@ class DAGScheduler( s"longer running") } + if (mapStage.rdd.isBarrier()) { + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(shuffleId) + } else if (mapId != -1) { + // Mark the map whose fetch failed as broken in the map stage + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + + if (failedStage.rdd.isBarrier()) { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Mark all the partitions of the result stage to be not finished, to ensure retry + // all the tasks on resubmitted stage attempt. + failedResultStage.activeJob.map(_.resetAllPartitions()) + } + } + if (shouldAbortStage) { val abortMessage = if (disallowStageRetryForTest) { "Fetch failure will not retry stage due to testing config" @@ -1375,7 +1388,7 @@ class DAGScheduler( // simpler while not producing an overwhelming number of scheduler events. logInfo( s"Resubmitting $mapStage (${mapStage.name}) and " + - s"$failedStage (${failedStage.name}) due to fetch failure" + s"$failedStage (${failedStage.name}) due to fetch failure" ) messageScheduler.schedule( new Runnable { @@ -1386,10 +1399,6 @@ class DAGScheduler( ) } } - // Mark the map whose fetch failed as broken in the map stage - if (mapId != -1) { - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { @@ -1411,6 +1420,76 @@ class DAGScheduler( } } + case failure: TaskFailedReason if task.isBarrier => + // Also handle the task failed reasons here. + failure match { + case Resubmitted => + handleResubmittedFailure(task, stage) + + case _ => // Do nothing. + } + + // Always fail the current stage and retry all the tasks when a barrier task fail. + val failedStage = stageIdToStage(task.stageId) + logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " + + "failed.") + val message = s"Stage failed because barrier task $task finished unsuccessfully. " + + failure.toErrorString + try { + // cancelTasks will fail if a SchedulerBackend does not implement killTask + taskScheduler.cancelTasks(stageId, interruptThread = false) + } catch { + case e: UnsupportedOperationException => + // Cannot continue with barrier stage if failed to cancel zombie barrier tasks. + // TODO SPARK-24877 leave the zombie tasks and ignore their completion events. + logWarning(s"Could not cancel tasks for stage $stageId", e) + abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " + + s"$failedStage (${failedStage.name})", Some(e)) + } + markStageAsFinished(failedStage, Some(message)) + + failedStage.failedAttemptIds.add(task.stageAttemptId) + // TODO Refactor the failure handling logic to combine similar code with that of + // FetchFailed. + val shouldAbortStage = + failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts || + disallowStageRetryForTest + + if (shouldAbortStage) { + val abortMessage = if (disallowStageRetryForTest) { + "Barrier stage will not retry stage due to testing config" + } else { + s"""$failedStage (${failedStage.name}) + |has failed the maximum allowable number of + |times: $maxConsecutiveStageAttempts. + |Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ") + } + abortStage(failedStage, abortMessage, None) + } else { + failedStage match { + case failedMapStage: ShuffleMapStage => + // Mark all the map as broken in the map stage, to ensure retry all the tasks on + // resubmitted stage attempt. + mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + + case failedResultStage: ResultStage => + // Mark all the partitions of the result stage to be not finished, to ensure retry + // all the tasks on resubmitted stage attempt. + failedResultStage.activeJob.map(_.resetAllPartitions()) + } + + // update failedStages and make sure a ResubmitFailedStages event is enqueued + failedStages += failedStage + logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage " + + "failure.") + messageScheduler.schedule(new Runnable { + override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages) + }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS) + } + + case Resubmitted => + handleResubmittedFailure(task, stage) + case _: TaskCommitDenied => // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits @@ -1426,6 +1505,18 @@ class DAGScheduler( } } + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { + logInfo(s"Resubmitted $task, so marking it as still running.") + stage match { + case sms: ShuffleMapStage => + sms.pendingPartitions += task.partitionId + + case _ => + throw new SparkException("TaskSetManagers should only send Resubmitted task " + + "statuses for tasks in ShuffleMapStages.") + } + } + private[scheduler] def markMapStageJobsAsFinished(shuffleStage: ShuffleMapStage): Unit = { // Mark any map-stage jobs waiting on this stage as finished if (shuffleStage.isAvailable && shuffleStage.mapStageJobs.nonEmpty) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e36c759a4255..aafeae05b566 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -48,7 +48,9 @@ import org.apache.spark.rdd.RDD * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to - */ + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. + */ private[spark] class ResultTask[T, U]( stageId: Int, stageAttemptId: Int, @@ -60,9 +62,10 @@ private[spark] class ResultTask[T, U]( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics, - jobId, appId, appAttemptId) + jobId, appId, appAttemptId, isBarrier) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 7a25c47e2cab..f2cd65fd523a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -49,6 +49,8 @@ import org.apache.spark.shuffle.ShuffleWriter * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] class ShuffleMapTask( stageId: Int, @@ -60,9 +62,10 @@ private[spark] class ShuffleMapTask( serializedTaskMetrics: Array[Byte], jobId: Option[Int] = None, appId: Option[String] = None, - appAttemptId: Option[String] = None) + appAttemptId: Option[String] = None, + isBarrier: Boolean = false) extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties, - serializedTaskMetrics, jobId, appId, appAttemptId) + serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 290fd073caf2..26cca334d3bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -82,15 +82,15 @@ private[scheduler] abstract class Stage( private var _latestInfo: StageInfo = StageInfo.fromStage(this, nextAttemptId) /** - * Set of stage attempt IDs that have failed with a FetchFailure. We keep track of these - * failures in order to avoid endless retries if a stage keeps failing with a FetchFailure. + * Set of stage attempt IDs that have failed. We keep track of these failures in order to avoid + * endless retries if a stage keeps failing. * We keep track of each attempt ID that has failed to avoid recording duplicate failures if * multiple tasks from the same stage attempt fail (SPARK-5945). */ - val fetchFailedAttemptIds = new HashSet[Int] + val failedAttemptIds = new HashSet[Int] private[scheduler] def clearFailures() : Unit = { - fetchFailedAttemptIds.clear() + failedAttemptIds.clear() } /** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index f536fc2a5f0a..89ff2038e5f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -49,6 +49,8 @@ import org.apache.spark.util._ * @param jobId id of the job this task belongs to * @param appId id of the app this task belongs to * @param appAttemptId attempt id of the app this task belongs to + * @param isBarrier whether this task belongs to a barrier stage. Spark must launch all the tasks + * at the same time for a barrier stage. */ private[spark] abstract class Task[T]( val stageId: Int, @@ -60,7 +62,8 @@ private[spark] abstract class Task[T]( SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), val jobId: Option[Int] = None, val appId: Option[String] = None, - val appAttemptId: Option[String] = None) extends Serializable { + val appAttemptId: Option[String] = None, + val isBarrier: Boolean = false) extends Serializable { @transient lazy val metrics: TaskMetrics = SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics)) @@ -77,16 +80,32 @@ private[spark] abstract class Task[T]( attemptNumber: Int, metricsSystem: MetricsSystem): T = { SparkEnv.get.blockManager.registerTask(taskAttemptId) - context = new TaskContextImpl( - stageId, - stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal - partitionId, - taskAttemptId, - attemptNumber, - taskMemoryManager, - localProperties, - metricsSystem, - metrics) + // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether + // the stage is barrier. + context = if (isBarrier) { + new BarrierTaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } else { + new TaskContextImpl( + stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal + partitionId, + taskAttemptId, + attemptNumber, + taskMemoryManager, + localProperties, + metricsSystem, + metrics) + } + TaskContext.setTaskContext(context) taskThread = Thread.currentThread() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index c98b87148e40..bb4a4442b943 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -50,6 +50,7 @@ private[spark] class TaskDescription( val executorId: String, val name: String, val index: Int, // Index within this task's TaskSet + val partitionId: Int, val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, @@ -76,6 +77,7 @@ private[spark] object TaskDescription { dataOut.writeUTF(taskDescription.executorId) dataOut.writeUTF(taskDescription.name) dataOut.writeInt(taskDescription.index) + dataOut.writeInt(taskDescription.partitionId) // Write files. serializeStringLongMap(taskDescription.addedFiles, dataOut) @@ -117,6 +119,7 @@ private[spark] object TaskDescription { val executorId = dataIn.readUTF() val name = dataIn.readUTF() val index = dataIn.readInt() + val partitionId = dataIn.readInt() // Read files. val taskFiles = deserializeStringLongMap(dataIn) @@ -138,7 +141,7 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + new TaskDescription(taskId, attemptNumber, executorId, name, index, partitionId, taskFiles, + taskJars, properties, serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 56c0bf6c0935..587ed4b5243b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -274,7 +274,8 @@ private[spark] class TaskSchedulerImpl( maxLocality: TaskLocality, shuffledOffers: Seq[WorkerOffer], availableCpus: Array[Int], - tasks: IndexedSeq[ArrayBuffer[TaskDescription]]) : Boolean = { + tasks: IndexedSeq[ArrayBuffer[TaskDescription]], + addressesWithDescs: ArrayBuffer[(String, TaskDescription)]) : Boolean = { var launchedTask = false // nodes and executors that are blacklisted for the entire application have already been // filtered out by this point @@ -291,6 +292,11 @@ private[spark] class TaskSchedulerImpl( executorIdToRunningTaskIds(execId).add(tid) availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) + // Only update hosts for a barrier task. + if (taskSet.isBarrier) { + // The executor address is expected to be non empty. + addressesWithDescs += (shuffledOffers(i).address.get -> task) + } launchedTask = true } } catch { @@ -346,6 +352,7 @@ private[spark] class TaskSchedulerImpl( // Build a list of tasks to assign to each worker. val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores / CPUS_PER_TASK)) val availableCpus = shuffledOffers.map(o => o.cores).toArray + val availableSlots = shuffledOffers.map(o => o.cores / CPUS_PER_TASK).sum val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( @@ -359,20 +366,55 @@ private[spark] class TaskSchedulerImpl( // of locality levels so that it gets a chance to launch local tasks on all of them. // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY for (taskSet <- sortedTaskSets) { - var launchedAnyTask = false - var launchedTaskAtCurrentMaxLocality = false - for (currentMaxLocality <- taskSet.myLocalityLevels) { - do { - launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet( - taskSet, currentMaxLocality, shuffledOffers, availableCpus, tasks) - launchedAnyTask |= launchedTaskAtCurrentMaxLocality - } while (launchedTaskAtCurrentMaxLocality) - } - if (!launchedAnyTask) { - taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + // Skip the barrier taskSet if the available slots are less than the number of pending tasks. + if (taskSet.isBarrier && availableSlots < taskSet.numTasks) { + // Skip the launch process. + // TODO SPARK-24819 If the job requires more slots than available (both busy and free + // slots), fail the job on submit. + logInfo(s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because the barrier taskSet requires ${taskSet.numTasks} slots, while the total " + + s"number of available slots is $availableSlots.") + } else { + var launchedAnyTask = false + // Record all the executor IDs assigned barrier tasks on. + val addressesWithDescs = ArrayBuffer[(String, TaskDescription)]() + for (currentMaxLocality <- taskSet.myLocalityLevels) { + var launchedTaskAtCurrentMaxLocality = false + do { + launchedTaskAtCurrentMaxLocality = resourceOfferSingleTaskSet(taskSet, + currentMaxLocality, shuffledOffers, availableCpus, tasks, addressesWithDescs) + launchedAnyTask |= launchedTaskAtCurrentMaxLocality + } while (launchedTaskAtCurrentMaxLocality) + } + if (!launchedAnyTask) { + taskSet.abortIfCompletelyBlacklisted(hostToExecutors) + } + if (launchedAnyTask && taskSet.isBarrier) { + // Check whether the barrier tasks are partially launched. + // TODO SPARK-24818 handle the assert failure case (that can happen when some locality + // requirements are not fulfilled, and we should revert the launched tasks). + require(addressesWithDescs.size == taskSet.numTasks, + s"Skip current round of resource offers for barrier stage ${taskSet.stageId} " + + s"because only ${addressesWithDescs.size} out of a total number of " + + s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + + "been blacklisted or cannot fulfill task locality requirements.") + + // Update the taskInfos into all the barrier task properties. + val addressesStr = addressesWithDescs + // Addresses ordered by partitionId + .sortBy(_._2.partitionId) + .map(_._1) + .mkString(",") + addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) + + logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + + s"stage ${taskSet.stageId}.") + } } } + // TODO SPARK-24823 Cancel a job that contains barrier stage(s) if the barrier tasks don't get + // launched within a configured time. if (tasks.size > 0) { hasLaunchedTask = true } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index defed1e0f9c6..0b21256ab6cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -123,6 +123,10 @@ private[spark] class TaskSetManager( // TODO: We should kill any running task attempts when the task set manager becomes a zombie. private[scheduler] var isZombie = false + // Whether the taskSet run tasks from a barrier stage. Spark must launch all the tasks at the + // same time for a barrier stage. + private[scheduler] def isBarrier = taskSet.tasks.nonEmpty && taskSet.tasks(0).isBarrier + // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the // ArrayBuffer and removed from the end. This makes it faster to detect @@ -512,6 +516,7 @@ private[spark] class TaskSetManager( execId, taskName, index, + task.partitionId, addedFiles, addedJars, task.localProperties, @@ -979,8 +984,8 @@ private[spark] class TaskSetManager( */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie. - if (isZombie || numTasks == 1) { + // zombie or is from a barrier stage. + if (isZombie || isBarrier || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala index 810b36cddf83..6ec74913e42f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/WorkerOffer.scala @@ -21,4 +21,10 @@ package org.apache.spark.scheduler * Represents free resources available on an executor. */ private[spark] -case class WorkerOffer(executorId: String, host: String, cores: Int) +case class WorkerOffer( + executorId: String, + host: String, + cores: Int, + // `address` is an optional hostPort string, it provide more useful information than `host` + // when multiple executors are launched on the same host. + address: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9b90e309d2e0..375aeb0c3466 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -242,7 +242,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val activeExecutors = executorDataMap.filterKeys(executorIsAlive) val workOffers = activeExecutors.map { case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + new WorkerOffer(id, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort)) }.toIndexedSeq scheduler.resourceOffers(workOffers) } @@ -267,7 +268,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp if (executorIsAlive(executorId)) { val executorData = executorDataMap(executorId) val workOffers = IndexedSeq( - new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores, + Some(executorData.executorAddress.hostPort))) scheduler.resourceOffers(workOffers) } else { Seq.empty diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 4c614c5c0f60..cf8b0ff4f701 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -81,7 +81,8 @@ private[spark] class LocalEndpoint( } def reviveOffers() { - val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores, + Some(rpcEnv.address.hostPort))) for (task <- scheduler.resourceOffers(offers).flatten) { freeCores -= scheduler.CPUS_PER_TASK executor.launchTask(executorBackend, task) diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index ce9f2be1c02d..e5f31a04c6e7 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -627,6 +627,48 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(exc.getCause() != null) stream.close() } + + test("support barrier execution mode under local mode") { + val conf = new SparkConf().setAppName("test").setMaster("local[2]") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } + + test("support barrier execution mode under local-cluster mode") { + val conf = new SparkConf() + .setMaster("local-cluster[3, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(Seq(1, 2, 3, 4), 2) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // If we don't get the expected taskInfos, the job shall abort due to stage failure. + if (context.getTaskInfos().length != 2) { + throw new SparkException("Expected taksInfos length is 2, actual length is " + + s"${context.getTaskInfos().length}.") + } + context.barrier() + it + } + rdd2.collect() + + eventually(timeout(10.seconds)) { + assert(sc.statusTracker.getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } } object SparkContextSuite { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 1a7bebe2c53c..77a7668d3a1d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -275,6 +275,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug executorId = "", name = "", index = 0, + partitionId = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), properties = new Properties, diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala new file mode 100644 index 000000000000..39d4618e4c6c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/RDDBarrierSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} + +class RDDBarrierSuite extends SparkFunSuite with SharedSparkContext { + + test("create an RDDBarrier") { + val rdd = sc.parallelize(1 to 10, 4) + assert(rdd.isBarrier() === false) + + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter) + assert(rdd2.isBarrier() === true) + } + + test("create an RDDBarrier in the middle of a chain of RDDs") { + val rdd = sc.parallelize(1 to 10, 4).map(x => x * 2) + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).map(x => (x, x + 1)) + assert(rdd2.isBarrier() === true) + } + + test("RDDBarrier with shuffle") { + val rdd = sc.parallelize(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions((iter, context) => iter).repartition(2) + assert(rdd2.isBarrier() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2987170bf502..b3db5e29fb82 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1055,6 +1055,64 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(sparkListener.failedStages.size == 1) } + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by FetchFailure") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)), + (Success, makeMapStatus("hostB", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // The first result task fails, with a fetch failure for the output from the first mapper. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + null)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 1) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + + test("Retry all the tasks on a resubmitted attempt of a barrier stage caused by TaskKilled") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil).barrier().mapPartitions((it, context) => it) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val reduceRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", reduceRdd.partitions.length)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(1))) + + // The second map task fails with TaskKilled. + runEvent(makeCompletionEvent( + taskSets(0).tasks(1), + TaskKilled("test"), + null)) + assert(sparkListener.failedStages === Seq(0)) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) + + scheduler.resubmitFailedStages() + // Complete the map stage. + completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 2) + + // Complete the result stage. + completeNextResultStageWithSuccess(1, 0) + + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assertDataStructuresEmpty() + } + /** * This tests the case where another FetchFailed comes in while the map stage is getting * re-run. diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 109d4a0a870b..b29d32f7b35c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -27,8 +27,10 @@ class FakeTask( partitionId: Int, prefLocs: Seq[TaskLocation] = Nil, serializedTaskMetrics: Array[Byte] = - SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) - extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics) { + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), + isBarrier: Boolean = false) + extends Task[Int](stageId, 0, partitionId, new Properties, serializedTaskMetrics, + isBarrier = isBarrier) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs @@ -74,4 +76,22 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createBarrierTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createBarrierTaskSet(numTasks, stageId = 0, stageAttempId = 0, prefLocs: _*) + } + + def createBarrierTaskSet( + numTasks: Int, + stageId: Int, + stageAttempId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = Array.tabulate[Task[_]](numTasks) { i => + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) + } + new TaskSet(tasks, stageId, stageAttempId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 97487ce1d2ca..ba62eec0522d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -62,6 +62,7 @@ class TaskDescriptionSuite extends SparkFunSuite { executorId = "testExecutor", name = "task for test", index = 19, + partitionId = 1, originalFiles, originalJars, originalProperties, @@ -77,6 +78,7 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.executorId === originalTaskDescription.executorId) assert(decodedTaskDescription.name === originalTaskDescription.name) assert(decodedTaskDescription.index === originalTaskDescription.index) + assert(decodedTaskDescription.partitionId === originalTaskDescription.partitionId) assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 33f2ea1c94e7..624384abcd71 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1021,4 +1021,38 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) } } + + test("don't schedule for a barrier taskSet if available slots are less than pending tasks") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, since the available slots are less than pending + // tasks, don't schedule barrier tasks on the resource offer. + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(0 === taskDescriptions.length) + } + + test("schedule tasks for a barrier taskSet if all tasks can be launched together") { + val taskCpus = 2 + val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) + + val numFreeCores = 3 + val workerOffers = IndexedSeq( + new WorkerOffer("executor0", "host0", numFreeCores, Some("192.168.0.101:49625")), + new WorkerOffer("executor1", "host1", numFreeCores, Some("192.168.0.101:49627")), + new WorkerOffer("executor2", "host2", numFreeCores, Some("192.168.0.101:49629"))) + val attempt1 = FakeTask.createBarrierTaskSet(3) + + // submit attempt 1, offer some resources, all tasks get launched together + taskScheduler.submitTasks(attempt1) + val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten + assert(3 === taskDescriptions.length) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 2d2f90c63a30..31f84310485a 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -253,6 +253,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(), @@ -361,6 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite executorId = "s1", name = "n1", index = 0, + partitionId = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], properties = new Properties(),