diff --git a/core/src/main/scala/org/apache/spark/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/CommitDeniedException.scala new file mode 100644 index 000000000000..3a910d433926 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/CommitDeniedException.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Exception thrown when a task attempts to commit output to Hadoop, but + * is denied by the driver. + */ +@DeveloperApi +class CommitDeniedException(msg: String, jobID: Int, splitID: Int, attemptID: Int) + extends Exception(msg) { + def toTaskEndReason(): TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID) +} + diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33..5526eea39f64 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,6 +20,8 @@ package org.apache.spark import java.io.File import java.net.Socket +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor + import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -34,7 +36,7 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ @@ -67,6 +69,7 @@ class SparkEnv ( val sparkFilesDir: String, val metricsSystem: MetricsSystem, val shuffleMemoryManager: ShuffleMemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { private[spark] var isStopped = false @@ -86,6 +89,7 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() + outputCommitCoordinator.stop() actorSystem.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release @@ -346,6 +350,10 @@ object SparkEnv extends Logging { "levels using the RDD.persist() method instead.") } + val outputCommitCoordinator = new OutputCommitCoordinator(conf) + val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", + new OutputCommitCoordinatorActor(outputCommitCoordinator)) + 
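The wiring here exposes the coordinator through SparkEnv on both the driver and the executors, which is what lets the Hadoop write path ask for permission before committing. For orientation, a distilled sketch of that commit gate (an illustrative helper, not code from this patch; it mirrors the change to SparkHadoopWriter.commit() further down, and like SparkHadoopWriter it would have to live inside the org.apache.spark package because OutputCommitCoordinator is private[spark]):

import org.apache.spark.{CommitDeniedException, SparkEnv}

// Hypothetical helper mirroring the gated commit added to SparkHadoopWriter.commit().
def gatedCommit(jobId: Int, splitId: Int, attemptId: Int)
               (doCommit: () => Unit, doAbort: () => Unit): Unit = {
  val coordinator = SparkEnv.get.outputCommitCoordinator
  if (coordinator.canCommit(jobId, splitId, attemptId)) {
    doCommit()   // only the single authorized attempt reaches this branch
  } else {
    doAbort()    // clean up this attempt's output
    throw new CommitDeniedException(
      s"attempt $attemptId of task $splitId in job $jobId was denied the commit",
      jobId, splitId, attemptId)
  }
}

Throwing CommitDeniedException rather than a plain exception matters because the executor (see Executor.scala below) converts it into the new TaskCommitDenied end reason instead of a generic failure.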
outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) new SparkEnv( executorId, actorSystem, @@ -362,6 +370,7 @@ object SparkEnv extends Logging { sparkFilesDir, metricsSystem, shuffleMemoryManager, + outputCommitCoordinator, conf) } diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 40237596570d..10db43624330 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD +import org.apache.spark.util.AkkaUtils /** * Internal helper class that saves an RDD using a Hadoop OutputFormat. @@ -106,18 +107,27 @@ class SparkHadoopWriter(@transient jobConf: JobConf) val taCtxt = getTaskContext() val cmtr = getOutputCommitter() if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => { - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e + val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID) + if (canCommit) { + try { + cmtr.commitTask(taCtxt) + logInfo (s"$taID: Committed") + } catch { + case e: IOException => { + logError("Error committing the output of task: " + taID.value, e) + cmtr.abortTask(taCtxt) + throw e + } } + } else { + val msg: String = s"$taID: Not committed because the driver did not authorize commit" + logInfo(msg) + cmtr.abortTask(taCtxt) + throw new CommitDeniedException(msg, jobID, splitID, attemptID) } } else { - logInfo ("No need to commit output of task: " + taID.value) + logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}") } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index af5fd8e0ac00..b9ed05c7e1b1 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -146,6 +146,20 @@ case object TaskKilled extends TaskFailedReason { override def toErrorString: String = "TaskKilled (killed intentionally)" } +/** + * :: DeveloperApi :: + * Task requested the driver to commit, but was denied. + */ +@DeveloperApi +case class TaskCommitDenied( + jobID: Int, + splitID: Int, + attemptID: Int) + extends TaskFailedReason { + override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" + + s" for job: $jobID, split: $splitID, attempt: $attemptID" +} + /** * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. 
This may happen because diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42566d1a1409..9c3f82e2b247 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -248,6 +248,11 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) } + case cDE: CommitDeniedException => { + val reason = cDE.toTaskEndReason + execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) + } + case t: Throwable => { // Attempt to exit cleanly by informing the driver of our failure. // If anything goes wrong (or this was a fatal exception), we will delegate to diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 1cfe98673773..81c6b2beb287 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage._ -import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils} +import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat /** @@ -63,7 +63,7 @@ class DAGScheduler( mapOutputTracker: MapOutputTrackerMaster, blockManagerMaster: BlockManagerMaster, env: SparkEnv, - clock: Clock = SystemClock) + clock: org.apache.spark.util.Clock = SystemClock) extends Logging { def this(sc: SparkContext, taskScheduler: TaskScheduler) = { @@ -126,6 +126,8 @@ class DAGScheduler( private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this) taskScheduler.setDAGScheduler(this) + private val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator + // Called by TaskScheduler to report task's starting. def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventProcessLoop.post(BeginEvent(task, taskInfo)) @@ -808,6 +810,7 @@ class DAGScheduler( // will be posted, which should always come after a corresponding SparkListenerStageSubmitted // event. stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size)) + outputCommitCoordinator.stageStart(stage.id) listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. @@ -865,6 +868,7 @@ class DAGScheduler( } else { // Because we posted SparkListenerStageSubmitted earlier, we should post // SparkListenerStageCompleted here in case there are no tasks to run. + outputCommitCoordinator.stageEnd(stage.id) listenerBus.post(SparkListenerStageCompleted(stage.latestInfo)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) @@ -909,6 +913,9 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + outputCommitCoordinator.taskCompleted(stageId, task.partitionId, + event.taskInfo.attempt, event.reason) + // The success case is dealt with separately below, since we need to compute accumulator // updates before posting. 
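Taken together, the hooks added to the DAGScheduler give the coordinator a per-stage lifecycle: stageStart registers the stage, canCommit implements first-committer-wins per task, taskCompleted releases the commit lock if an authorized committer fails, and stageEnd discards the stage's state. A driver-side sketch of that sequence (illustrative values only; it assumes the coordinator actor has been wired up as in SparkEnv.create, and since the class is private[spark] this would live inside Spark itself):

import org.apache.spark.{SparkEnv, Success}

val coordinator = SparkEnv.get.outputCommitCoordinator
coordinator.stageStart(stage = 1)
val first  = coordinator.canCommit(stage = 1, task = 0L, attempt = 0L)  // true: first attempt is authorized
val second = coordinator.canCommit(stage = 1, task = 0L, attempt = 1L)  // false: task 0 already has a committer
coordinator.taskCompleted(stage = 1, task = 0L, attempt = 0L, reason = Success)
coordinator.stageEnd(stage = 1)  // drops all committer state for the stage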
if (event.reason != Success) { @@ -921,6 +928,7 @@ class DAGScheduler( // Skip all the actions if the stage has been cancelled. return } + val stage = stageIdToStage(task.stageId) def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = { @@ -1073,6 +1081,9 @@ class DAGScheduler( handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } + case TaskCommitDenied(jobID, splitID, attemptID) => + // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala new file mode 100644 index 000000000000..5f6f6f3422b2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import akka.actor.{ActorRef, Actor} + +import org.apache.spark._ +import org.apache.spark.util.{AkkaUtils, ActorLogReceive} + +private[spark] sealed trait OutputCommitCoordinationMessage extends Serializable + +private[spark] case class StageStarted(stage: Int) extends OutputCommitCoordinationMessage +private[spark] case class StageEnded(stage: Int) extends OutputCommitCoordinationMessage +private[spark] case object StopCoordinator extends OutputCommitCoordinationMessage + +private[spark] case class AskPermissionToCommitOutput( + stage: Int, + task: Long, + taskAttempt: Long) + extends OutputCommitCoordinationMessage + +private[spark] case class TaskCompleted( + stage: Int, + task: Long, + attempt: Long, + reason: TaskEndReason) + extends OutputCommitCoordinationMessage + +/** + * Authority that decides whether tasks can commit output to HDFS. + * + * This lives on the driver, but the actor allows the tasks that commit + * to Hadoop to invoke it. 
+ */ +private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { + + // Initialized by SparkEnv + var coordinatorActor: Option[ActorRef] = None + private val timeout = AkkaUtils.askTimeout(conf) + private val maxAttempts = AkkaUtils.numRetries(conf) + private val retryInterval = AkkaUtils.retryWaitMs(conf) + + private type StageId = Int + private type TaskId = Long + private type TaskAttemptId = Long + private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[TaskId, TaskAttemptId]] + + private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map() + + def stageStart(stage: StageId) { + sendToActor(StageStarted(stage)) + } + def stageEnd(stage: StageId) { + sendToActor(StageEnded(stage)) + } + + def canCommit( + stage: StageId, + task: TaskId, + attempt: TaskAttemptId): Boolean = { + askActor(AskPermissionToCommitOutput(stage, task, attempt)) + } + + def taskCompleted( + stage: StageId, + task: TaskId, + attempt: TaskAttemptId, + reason: TaskEndReason) { + sendToActor(TaskCompleted(stage, task, attempt, reason)) + } + + def stop() { + sendToActor(StopCoordinator) + coordinatorActor = None + authorizedCommittersByStage.foreach(_._2.clear) + authorizedCommittersByStage.clear + } + + private def handleStageStart(stage: StageId): Unit = { + authorizedCommittersByStage(stage) = mutable.HashMap[TaskId, TaskAttemptId]() + } + + private def handleStageEnd(stage: StageId): Unit = { + authorizedCommittersByStage.remove(stage) + } + + private def handleAskPermissionToCommit( + stage: StageId, + task: TaskId, + attempt: TaskAttemptId): + Boolean = { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + authorizedCommitters.get(task) match { + case Some(existingCommitter) => + logDebug(s"Denying $attempt to commit for stage=$stage, task=$task; " + + s"existingCommitter = $existingCommitter") + false + case None => + logDebug(s"Authorizing $attempt to commit for stage=$stage, task=$task") + authorizedCommitters(task) = attempt + true + } + case None => + logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit") + return false + } + } + + private def handleTaskCompletion( + stage: StageId, + task: TaskId, + attempt: TaskAttemptId, + reason: TaskEndReason): Unit = { + authorizedCommittersByStage.get(stage) match { + case Some(authorizedCommitters) => + reason match { + case Success => return + case TaskCommitDenied(jobID, splitID, attemptID) => + logInfo(s"Commit denied: stage=$stage, task=$task, attempt=$attempt") + case otherReason => + logDebug(s"Authorized committer $attempt (stage=$stage, task=$task) failed;" + + s" clearing lock") + authorizedCommitters.remove(task) + } + case None => + logDebug(s"Ignoring task completion for completed stage") + } + } + + private def sendToActor(msg: OutputCommitCoordinationMessage) { + coordinatorActor.foreach(_ ! 
msg) + } + + private def askActor(msg: OutputCommitCoordinationMessage): Boolean = { + coordinatorActor + .map(AkkaUtils.askWithReply[Boolean](msg, _, maxAttempts, retryInterval, timeout)) + .getOrElse(false) + } +} + +private[spark] object OutputCommitCoordinator { + + class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) + extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case StageStarted(stage) => + outputCommitCoordinator.handleStageStart(stage) + case StageEnded(stage) => + outputCommitCoordinator.handleStageEnd(stage) + case AskPermissionToCommitOutput(stage, task, taskAttempt) => + sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, task, taskAttempt) + case TaskCompleted(stage, task, attempt, reason) => + outputCommitCoordinator.handleTaskCompletion(stage, task, attempt, reason) + case StopCoordinator => + logInfo("OutputCommitCoordinator stopped!") + context.stop(self) + sender ! true + } + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a1dfb0106259..d49792d774c9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl( val tasks = taskSet.tasks logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks") this.synchronized { - val manager = new TaskSetManager(this, taskSet, maxTaskFailures) + val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) @@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl( backend.reviveOffers() } + // Label as private[scheduler] to allow tests to swap in different task set managers if necessary + private[scheduler] def createTaskSetManager( + taskSet: TaskSet, + maxTaskFailures: Int): TaskSetManager = { + new TaskSetManager(this, taskSet, maxTaskFailures) + } + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { logInfo("Cancelling stage " + stageId) activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 5c94c6bbcb37..afd40cb4bf1e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -292,7 +292,8 @@ private[spark] class TaskSetManager( * an attempt running on this host, in case the host is slow. In addition, the task should meet * the given locality constraint. 
*/ - private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) + // Labeled as protected to allow tests to override providing speculative tasks if necessary + protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value) : Option[(Int, TaskLocality.Value)] = { speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set @@ -596,7 +597,9 @@ private[spark] class TaskSetManager( removeRunningTask(tid) info.markFailed() val index = info.index - copiesRunning(index) -= 1 + if (copiesRunning(index) >= 1) { + copiesRunning(index) -= 1 + } var taskMetrics : TaskMetrics = null val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " + @@ -646,8 +649,12 @@ private[spark] class TaskSetManager( s"${ef.className} (${ef.description}) [duplicate $dupCount]") } + case e: TaskCommitDenied => + logWarning(failureReason) + case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others logWarning(failureReason) + return case e: TaskEndReason => logError("Unknown TaskEndReason: " + e) @@ -656,17 +663,20 @@ private[spark] class TaskSetManager( failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()). put(info.executorId, clock.getTime()) sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) - addPendingTask(index) - if (!isZombie && state != TaskState.KILLED) { - assert (null != failureReason) - numFailures(index) += 1 - if (numFailures(index) >= maxTaskFailures) { - logError("Task %d in stage %s failed %d times; aborting job".format( - index, taskSet.id, maxTaskFailures)) - abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" - .format(index, taskSet.id, maxTaskFailures, failureReason)) - return + if (!reason.isInstanceOf[TaskCommitDenied]) { + addPendingTask(index) + if (!isZombie && state != TaskState.KILLED) { + assert (null != failureReason) + numFailures(index) += 1 + if (numFailures(index) >= maxTaskFailures) { + logError("Task %d in stage %s failed %d times; aborting job".format( + index, taskSet.id, maxTaskFailures)) + abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:" + .format(index, taskSet.id, maxTaskFailures, failureReason)) + return + } } + } maybeFinishTaskSet() } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSingleThreadedProcessLoop.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSingleThreadedProcessLoop.scala new file mode 100644 index 000000000000..3df2146a449f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSingleThreadedProcessLoop.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import scala.util.control.NonFatal + +class DAGSchedulerSingleThreadedProcessLoop(dagScheduler: DAGScheduler) + extends DAGSchedulerEventProcessLoop(dagScheduler) { + + override def post(event: DAGSchedulerEvent): Unit = { + try { + // Forward event to `onReceive` directly to avoid processing event asynchronously. + onReceive(event) + } catch { + case NonFatal(e) => onError(e) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index eb116213f69f..6d8389c5db03 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} import scala.language.reflectiveCalls -import scala.util.control.NonFatal +import org.mockito.Mockito.mock import org.scalatest.{BeforeAndAfter, FunSuiteLike} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -32,19 +32,6 @@ import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics -class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) - extends DAGSchedulerEventProcessLoop(dagScheduler) { - - override def post(event: DAGSchedulerEvent): Unit = { - try { - // Forward event to `onReceive` directly to avoid processing event asynchronously. - onReceive(event) - } catch { - case NonFatal(e) => onError(e) - } - } -} - /** * An RDD for passing to DAGScheduler. These RDDs will use the dependencies and * preferredLocations (if any) that are passed to them. They are deliberately not executable @@ -171,7 +158,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar runLocallyWithinThread(job) } } - dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + dagEventProcessLoopTester = new DAGSchedulerSingleThreadedProcessLoop(scheduler) } override def afterAll() { @@ -208,7 +195,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(taskSet.tasks.size >= results.size) for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { - runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, null, createFakeTaskInfo(), null)) } } } @@ -219,7 +206,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar for ((result, i) <- results.zipWithIndex) { if (i < taskSet.tasks.size) { runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, - Map[Long, Any]((accumId, 1)), null, null)) + Map[Long, Any]((accumId, 1)), createFakeTaskInfo(), null)) } } } @@ -399,7 +386,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar runLocallyWithinThread(job) } } - dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(noKillScheduler) + dagEventProcessLoopTester = new DAGSchedulerSingleThreadedProcessLoop(noKillScheduler) val jobId = submit(new MyRDD(sc, 1, Nil), Array(0)) cancel(jobId) // Because the job wasn't actually cancelled, we shouldn't have received a failure message. 
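The switch from null task infos to createFakeTaskInfo() in these completion events is needed because DAGScheduler.handleTaskCompletion now dereferences event.taskInfo.attempt for the OutputCommitCoordinator. A sketch of driving a completion event synchronously in a test (it assumes this suite's scheduler and taskSets fixtures plus Mockito, as used elsewhere in the file):

import org.mockito.Mockito.mock

// Inside a DAGSchedulerSuite test body: post() runs onReceive inline, so the event is fully
// processed (including the coordinator's taskCompleted hook) before the next assertion runs.
val loop = new DAGSchedulerSingleThreadedProcessLoop(scheduler)
loop.post(CompletionEvent(
  taskSets(0).tasks(0), Success, 42, null, mock(classOf[TaskInfo]), null))

Relatedly, the TaskSetManager change earlier in this diff neither re-queues a task that fails with TaskCommitDenied nor counts that failure toward spark.task.maxFailures, because a denied commit means another attempt was authorized to commit, not that this task's code failed.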
@@ -476,7 +463,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.failedStages.contains(1)) @@ -487,7 +474,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), null, Map[Long, Any](), - null, + createFakeTaskInfo(), null)) // The SparkListener should not receive redundant failure events. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) @@ -507,14 +494,14 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(newEpoch > oldEpoch) val taskSet = taskSets(0) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a non-failed host - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old - runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) // should work because it's a new epoch taskSet.tasks(1).epoch = newEpoch - runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, null, null)) + runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) @@ -766,5 +753,10 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar assert(scheduler.shuffleToMapStage.isEmpty) assert(scheduler.waitingStages.isEmpty) } + + // Nothing in this test should break if the task info's fields are null, but + // OutputCommitCoordinator requires the task info itself to not be null. + private def createFakeTaskInfo(): TaskInfo = mock(classOf[TaskInfo]) + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala new file mode 100644 index 000000000000..eed7f0991978 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.io.{File, ObjectInputStream, ObjectOutputStream, IOException} + +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Timeouts +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} + +import org.apache.spark._ +import org.apache.spark.rdd.FakeOutputCommitter +import org.apache.spark.util.Utils + +import scala.collection.mutable.ArrayBuffer + +/** + * Unit tests for the output commit coordination functionality. + * + * The tests make both the original task and a speculated task attempt to commit, + * where committing is emulated by creating a directory. If both tasks create + * directories then the test fails. + * + * Note that some aspects of this test are less than ideal. In particular, the test + * mocks the speculation-dequeuing logic so that it always dequeues a task and treats + * it as speculated. Immediately after the tasks are submitted and reviveOffers() is + * called, reviveOffers() is invoked again to pick up the speculated task. This may + * exercise the scheduler in a somewhat unrealistic way. + * + * Also, validation is done by counting the files in a directory. Ideally an + * accumulator would be used instead, incremented inside the output committer's + * commitTask() call; if commitTask() were erroneously invoked twice, the test would + * fail because the accumulator would be incremented twice. + * + * The problem with that approach is that when both a speculated task and its original + * counterpart complete, only one of the accumulator updates is captured. As a result, + * the tests would still pass even if the OutputCommitCoordinator logic were removed + * from SparkHadoopWriter and both tasks committed erroneously.
+ */ +class OutputCommitCoordinatorSuite + extends FunSuite + with BeforeAndAfter + with Timeouts { + + val conf = new SparkConf() + .set("spark.localExecution.enabled", "true") + + var dagScheduler: DAGScheduler = null + var tempDir: File = null + var tempDirPath: String = null + var sc: SparkContext = null + + before { + sc = new SparkContext("local[4]", "Output Commit Coordinator Suite") + tempDir = Utils.createTempDir() + tempDirPath = tempDir.getAbsolutePath() + // Use Mockito.spy() to maintain the default infrastructure everywhere else + val mockTaskScheduler = spy(sc.taskScheduler.asInstanceOf[TaskSchedulerImpl]) + + doAnswer(new Answer[Unit]() { + override def answer(invoke: InvocationOnMock): Unit = { + // Submit the tasks, then, force the task scheduler to dequeue the + // speculated task + invoke.callRealMethod() + mockTaskScheduler.backend.reviveOffers() + } + }).when(mockTaskScheduler).submitTasks(Matchers.any()) + + doAnswer(new Answer[TaskSetManager]() { + override def answer(invoke: InvocationOnMock): TaskSetManager = { + val taskSet = invoke.getArguments()(0).asInstanceOf[TaskSet] + return new TaskSetManager(mockTaskScheduler, taskSet, 4) { + var hasDequeuedSpeculatedTask = false + override def dequeueSpeculativeTask( + execId: String, + host: String, + locality: TaskLocality.Value): Option[(Int, TaskLocality.Value)] = { + if (!hasDequeuedSpeculatedTask) { + hasDequeuedSpeculatedTask = true + return Some(0, TaskLocality.PROCESS_LOCAL) + } else { + return None + } + } + } + } + }).when(mockTaskScheduler).createTaskSetManager(Matchers.any(), Matchers.any()) + + sc.taskScheduler = mockTaskScheduler + val dagSchedulerWithMockTaskScheduler = new DAGScheduler(sc, mockTaskScheduler) + sc.taskScheduler.setDAGScheduler(dagSchedulerWithMockTaskScheduler) + sc.dagScheduler = dagSchedulerWithMockTaskScheduler + } + + after { + sc.stop() + tempDir.delete() + } + + /** + * Function that constructs a SparkHadoopWriter with a mock committer and runs its commit + */ + private class OutputCommittingFunction(private var tempDirPath: String) + extends ((TaskContext, Iterator[Int]) => Int) with Serializable { + + def apply(ctxt: TaskContext, it: Iterator[Int]): Int = { + val outputCommitter = new FakeOutputCommitter { + override def commitTask(context: TaskAttemptContext) : Unit = { + Utils.createDirectory(tempDirPath) + } + } + runCommitWithProvidedCommitter(ctxt, it, outputCommitter) + } + + protected def runCommitWithProvidedCommitter( + ctxt: TaskContext, + it: Iterator[Int], + outputCommitter: OutputCommitter): Int = { + def jobConf = new JobConf { + override def getOutputCommitter(): OutputCommitter = outputCommitter + } + val sparkHadoopWriter = new SparkHadoopWriter(jobConf) { + override def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + mock(classOf[TaskAttemptContext]) + } + } + sparkHadoopWriter.setup(ctxt.stageId, ctxt.partitionId, ctxt.attemptNumber) + sparkHadoopWriter.commit + 0 + } + + // Need this otherwise the entire test suite attempts to be serialized + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + out.writeUTF(tempDirPath) + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = { + tempDirPath = in.readUTF() + } + } + + /** + * Function that will explicitly fail to commit on the first attempt + */ + private class FailFirstTimeCommittingFunction(private var tempDirPath: String) + extends OutputCommittingFunction(tempDirPath) { + override def 
apply(ctxt: TaskContext, it: Iterator[Int]): Int = { + if (ctxt.attemptNumber == 0) { + val outputCommitter = new FakeOutputCommitter { + override def commitTask(taskAttemptContext: TaskAttemptContext) { + throw new RuntimeException + } + } + runCommitWithProvidedCommitter(ctxt, it, outputCommitter) + } else { + super.apply(ctxt, it) + } + } + } + + test("Only one of two duplicate commit tasks should commit") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, new OutputCommittingFunction(tempDirPath), + 0 until rdd.partitions.size, allowLocal = true) + assert(tempDir.list().size === 1) + } + + test("If commit fails, the task should not be locked and a retried attempt should succeed") { + val rdd = sc.parallelize(Seq(1), 1) + sc.runJob(rdd, new FailFirstTimeCommittingFunction(tempDirPath), + 0 until rdd.partitions.size, allowLocal = true) + assert(tempDir.list().size === 1) + } +}
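Finally, an end-to-end sketch of the scenario this patch guards against (application name, master, output path, and parallelism are illustrative): with speculation enabled, two attempts of the same partition can both reach commit time, and the coordinator ensures only one of them commits its Hadoop output while the other fails with TaskCommitDenied and is not re-queued.

import org.apache.spark.{SparkConf, SparkContext}

object SpeculativeOutputExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("output-commit-coordination-example")
      .setMaster("local[4]")
      .set("spark.speculation", "true")  // enables speculative task attempts
    val sc = new SparkContext(conf)
    try {
      // saveAsTextFile goes through SparkHadoopWriter, so each task attempt must obtain
      // the driver's permission via OutputCommitCoordinator.canCommit before committing.
      sc.parallelize(1 to 10000, 8)
        .map(i => (i, i * i))
        .saveAsTextFile("/tmp/output-commit-coordination-example")  // illustrative path
    } finally {
      sc.stop()
    }
  }
}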