diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 894091761485..cd2a7a657088 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -19,9 +19,12 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import java.util.concurrent._ +import java.util.Collections import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await +import scala.collection.JavaConversions._ import akka.actor._ import akka.pattern.ask @@ -30,11 +33,13 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId import org.apache.spark.util._ +import scala.collection.mutable private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case class GetShuffleStatus(shuffleId: Int) extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) @@ -64,6 +69,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster logInfo("MapOutputTrackerActor stopped!") sender ! true context.stop(self) + + case GetShuffleStatus(shuffleId: Int) => + sender ! tracker.completenessForShuffle(shuffleId) } } @@ -87,6 +95,14 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging */ protected val mapStatuses: Map[Int, Array[MapStatus]] + // Track if we have partial map outputs for a shuffle + protected val partialForShuffle = + Collections.newSetFromMap[Int](new ConcurrentHashMap[Int, java.lang.Boolean]()) + + protected val partialEpoch = new mutable.HashMap[Int, Int]() + + protected val updaterLock = new ConcurrentHashMap[Int, AnyRef]() + /** * Incremented every time a fetch fails so that client nodes know to clear * their cache of map output locations if this happens. @@ -126,6 +142,45 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * a given shuffle. */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + val statuses = getMapStatusesForShuffle(shuffleId, reduceId) + statuses.synchronized { + MapOutputTracker.convertMapStatuses( + shuffleId, reduceId, statuses, isPartial = partialForShuffle.contains(shuffleId)) + } + } + + /** Called to get current epoch number. */ + def getEpoch: Long = { + epochLock.synchronized { + return epoch + } + } + + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each worker task calls this with the latest epoch + * number on the master at the time it was created. + */ + def updateEpoch(newEpoch: Long) { + epochLock.synchronized { + if (newEpoch > epoch) { + logInfo("Updating epoch from "+epoch+" to " + newEpoch + " and clearing cache") + epoch = newEpoch + mapStatuses.clear() + } + } + } + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } + + /** Stop the tracker. 
*/ + def stop() { } + + // Get map statuses for a shuffle + private def getMapStatusesForShuffle(shuffleId: Int, reduceId: Int): Array[MapStatus]={ val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -158,7 +213,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging try { val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] - fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + val fetchedResults = MapOutputTracker.deserializeMapStatuses(fetchedBytes) + fetchedStatuses = fetchedResults._1 + if (fetchedResults._2) { + if(partialForShuffle.add(shuffleId)){ + new Thread(new MapStatusUpdater(shuffleId)).start() + } + } else { + partialForShuffle -= shuffleId + } logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) } finally { @@ -169,49 +232,92 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } if (fetchedStatuses != null) { - fetchedStatuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) - } + fetchedStatuses } else { throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } } else { - statuses.synchronized { - return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) - } + statuses } } - /** Called to get current epoch number. */ - def getEpoch: Long = { - epochLock.synchronized { - return epoch + // Clear outdated map outputs for a shuffle + private def clearOutdatedMapStatuses(shuffleId: Int): Boolean = { + if (mapStatuses.contains(shuffleId)) { + val masterCompleteness = askTracker(GetShuffleStatus(shuffleId)).asInstanceOf[Int] + val diff = masterCompleteness - completenessForShuffle(shuffleId) + if (diff > 0) { + logInfo("Master is " + diff + " map statuses ahead of us for shuffle " + + shuffleId + ". Clear local cache.") + mapStatuses -= shuffleId + return true + } else { + return false + } } + true } - /** - * Called from executors to update the epoch number, potentially clearing old outputs - * because of a fetch failure. Each worker task calls this with the latest epoch - * number on the master at the time it was created. 
- */ - def updateEpoch(newEpoch: Long) { - epochLock.synchronized { - if (newEpoch > epoch) { - logInfo("Updating epoch to " + newEpoch + " and clearing cache") - epoch = newEpoch - mapStatuses.clear() + // Compute the completeness of map statuses for a shuffle + def completenessForShuffle(shuffleId: Int): Int = { + mapStatuses.getOrElse(shuffleId, new Array[MapStatus](0)).count(_ != null) + } + + // A proxy to update partial map statuses periodically + class MapStatusUpdater(shuffleId: Int) extends Runnable { + override def run() { + updaterLock.put(shuffleId, new AnyRef) + partialEpoch.synchronized { + if (!partialEpoch.contains(shuffleId)) { + partialEpoch.put(shuffleId, 0) + } + } + logInfo("Updater started for shuffle " + shuffleId + ".") + val minInterval = 1000 + val maxInterval = 3000 + var lastUpdate = System.currentTimeMillis() + while (partialForShuffle.contains(shuffleId)) { + updaterLock.getOrElseUpdate(shuffleId, new AnyRef).synchronized { + updaterLock(shuffleId).wait(maxInterval) + } + val interval = System.currentTimeMillis() - lastUpdate + if (interval < minInterval) { + Thread.sleep(minInterval - interval) + } + lastUpdate = System.currentTimeMillis() + if (clearOutdatedMapStatuses(shuffleId)) { + getMapStatusesForShuffle(shuffleId, -1) + partialEpoch.synchronized { + partialEpoch.put(shuffleId, partialEpoch.getOrElse(shuffleId, 0) + 1) + partialEpoch.notifyAll() + } + } + } + logInfo("Map status for shuffle " + shuffleId + " is now complete. Updater terminated.") + partialEpoch.synchronized { + partialEpoch.remove(shuffleId) + partialEpoch.notifyAll() } } } - /** Unregister shuffle data. */ - def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) + def getUpdatedStatus( + shuffleId: Int, reduceId: Int, localEpoch: Int): (Array[(BlockManagerId, Long)], Int) = { + partialEpoch.synchronized { + if (!partialEpoch.contains(shuffleId)) { + return (getServerStatuses(shuffleId, reduceId), 0) + } + if (partialEpoch.get(shuffleId).get <= localEpoch) { + updaterLock.getOrElseUpdate(shuffleId, new AnyRef).synchronized { + updaterLock(shuffleId).notifyAll() + } + logInfo("Reduce "+reduceId+" waiting for map outputs of shuffle "+shuffleId+".") + partialEpoch.wait() + } + (getServerStatuses(shuffleId, reduceId), partialEpoch.getOrElse(shuffleId, 0)) + } } - - /** Stop the tracker. */ - def stop() { } } /** @@ -229,7 +335,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). * Other than these two scenarios, nothing should be dropped from this HashMap. */ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() with + mutable.SynchronizedMap[Int, Array[MapStatus]] private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() // For cleaning up TimeStampedHashMaps @@ -240,6 +347,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } + // We allow partial output by default. 
Should be properly set later when map outputs are registered + partialForShuffle += shuffleId } def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { @@ -250,11 +359,18 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } /** Register multiple map output information for the given shuffle */ - def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { + def registerMapOutputs( + shuffleId: Int, statuses: Array[MapStatus], + changeEpoch: Boolean = false, isPartial: Boolean = false) { mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) if (changeEpoch) { incrementEpoch() } + if (isPartial) { + partialForShuffle += shuffleId + } else { + partialForShuffle -= shuffleId + } } /** Unregister map output information of the given shuffle, mapper and block manager */ @@ -268,6 +384,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } incrementEpoch() + partialForShuffle += shuffleId } else { throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") } @@ -275,7 +392,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) /** Unregister shuffle data */ override def unregisterShuffle(shuffleId: Int) { - mapStatuses.remove(shuffleId) + super.unregisterShuffle(shuffleId) cachedSerializedStatuses.remove(shuffleId) } @@ -309,11 +426,13 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } // If we got here, we failed to find the serialized locations in the cache, so we pulled // out a snapshot of the locations as "statuses"; let's serialize and return that - val bytes = MapOutputTracker.serializeMapStatuses(statuses) + val partial = partialForShuffle.contains(shuffleId) + val bytes = MapOutputTracker.serializeMapStatuses(statuses, isPartial = partial) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working epochLock.synchronized { - if (epoch == epochGotten) { + // Don't put partial outputs in cache + if (epoch == epochGotten && !partial) { cachedSerializedStatuses(shuffleId) = bytes } } @@ -340,6 +459,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { protected val mapStatuses = new HashMap[Int, Array[MapStatus]] + with mutable.SynchronizedMap[Int, Array[MapStatus]] } private[spark] object MapOutputTracker { @@ -348,21 +468,24 @@ private[spark] object MapOutputTracker { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will // generally be pretty compressible because many map outputs will be on the same hostname. - def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = { + def serializeMapStatuses(statuses: Array[MapStatus], isPartial: Boolean = false): Array[Byte] = { val out = new ByteArrayOutputStream val objOut = new ObjectOutputStream(new GZIPOutputStream(out)) // Since statuses can be modified in parallel, sync on it statuses.synchronized { objOut.writeObject(statuses) + objOut.writeBoolean(isPartial) } objOut.close() out.toByteArray } // Opposite of serializeMapStatuses.
- def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = { + def deserializeMapStatuses(bytes: Array[Byte]): (Array[MapStatus], Boolean) = { val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes))) - objIn.readObject().asInstanceOf[Array[MapStatus]] + val mapStatuses = objIn.readObject().asInstanceOf[Array[MapStatus]] + val isPartial = objIn.readBoolean() + (mapStatuses, isPartial) } // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If @@ -371,13 +494,18 @@ private[spark] object MapOutputTracker { private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus], + isPartial: Boolean = false): Array[(BlockManagerId, Long)] = { assert (statuses != null) statuses.map { status => if (status == null) { - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) + if(isPartial){ + (null, 0.toLong) + } else { + throw new MetadataFetchFailedException( + shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) + } } else { (status.location, decompressSize(status.compressedSizes(reduceId))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 81c136d97031..2726231ed271 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils} +import org.apache.spark.util.{AkkaUtils, CallSite, SystemClock, Clock, Utils} /** * The high-level scheduling layer that implements stage-oriented scheduling. 
It computes a DAG of @@ -121,6 +121,11 @@ class DAGScheduler( private[scheduler] var eventProcessActor: ActorRef = _ + // Whether to enable removing the stage barrier + val removeStageBarrier = env.conf.getBoolean("spark.scheduler.removeStageBarrier", false) + // Track the pre-started stages depending on a stage (the key) + private val dependantStagePreStarted = new HashMap[Stage, ArrayBuffer[Stage]]() + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -243,7 +248,7 @@ class DAGScheduler( val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) - val locs = MapOutputTracker.deserializeMapStatuses(serLocs) + val locs = MapOutputTracker.deserializeMapStatuses(serLocs)._1 for (i <- 0 until locs.size) { stage.outputLocs(i) = Option(locs(i)).toList // locs(i) will be null if missing } @@ -853,6 +858,13 @@ class DAGScheduler( logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) } else { stage.addOutputLoc(smt.partitionId, status) + // Need to register map outputs progressively if removing the stage barrier is enabled + if (removeStageBarrier && dependantStagePreStarted.contains(stage) && + stage.shuffleDep.isDefined) { + mapOutputTracker.registerMapOutputs(stage.shuffleDep.get.shuffleId, + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = false, isPartial = true) + } } if (runningStages.contains(stage) && pendingTasks(stage).isEmpty) { markStageAsFinished(stage) @@ -879,6 +891,17 @@ class DAGScheduler( logInfo("Resubmitting " + stage + " (" + stage.name + ") because some of its tasks had failed: " + stage.outputLocs.zipWithIndex.filter(_._1 == Nil).map(_._2).mkString(", ")) + // Pre-started dependant stages should fail + val stages = new ArrayBuffer[Stage]() + if (dependantStagePreStarted.contains(stage)) { + for (preStartedStage <- dependantStagePreStarted.get(stage).get) { + logInfo("Marking " + preStartedStage + " (" + preStartedStage.name + + ") for resubmission due to parent stage resubmission") + runningStages -= preStartedStage + stages += preStartedStage + } + } + failStages(stages.toArray) submitStage(stage) } else { val newlyRunnable = new ArrayBuffer[Stage] @@ -898,6 +921,9 @@ class DAGScheduler( submitMissingTasks(stage, jobId) } } + dependantStagePreStarted -= stage + } else { + maybePreStartWaitingStage(stage) } } @@ -908,31 +934,26 @@ class DAGScheduler( case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => // Mark the stage that the reducer was in as unrunnable val failedStage = stageIdToStage(task.stageId) - runningStages -= failedStage - // TODO: Cancel running tasks in the stage - logInfo("Marking " + failedStage + " (" + failedStage.name + - ") for resubmision due to a fetch failure") - // Mark the map whose fetch failed as broken in the map stage - val mapStage = shuffleToMapStage(shuffleId) - if (mapId != -1) { - mapStage.removeOutputLoc(mapId, bmAddress) - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) - } - logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + - "); marking it for resubmission") - if (failedStages.isEmpty && eventProcessActor != null) { - // Don't schedule an event to resubmit failed stages if failed isn't empty, because - // in that case the event will already have been scheduled. 
eventProcessActor may be - // null during unit tests. - import env.actorSystem.dispatcher - env.actorSystem.scheduler.scheduleOnce( - RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) - } - failedStages += failedStage - failedStages += mapStage - // TODO: mark the executor as failed only if there were lots of fetch failures on it - if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + if (runningStages.remove(failedStage)) { + val stages = new ArrayBuffer[Stage]() + stages += failedStage + logInfo("Marking " + failedStage + " (" + failedStage.name + + ") for resubmission due to a fetch failure") + // Mark the map whose fetch failed as broken in the map stage + val mapStage = shuffleToMapStage(shuffleId) + if (mapId != -1) { + mapStage.removeOutputLoc(mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + } + runningStages -= mapStage + stages += mapStage + logInfo("The failed fetch was from " + mapStage + " (" + mapStage.name + + "); marking it for resubmission") + failStages(stages.toArray) + // TODO: mark the executor as failed only if there were lots of fetch failures on it + if (bmAddress != null) { + handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + } } case ExceptionFailure(className, description, stackTrace, metrics) => @@ -965,7 +986,7 @@ class DAGScheduler( for ((shuffleId, stage) <- shuffleToMapStage) { stage.removeOutputsOnExecutor(execId) val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true, isPartial = true) } if (shuffleToMapStage.isEmpty) { mapOutputTracker.incrementEpoch() @@ -1154,6 +1175,91 @@ class DAGScheduler( dagSchedulerActorSupervisor ! PoisonPill taskScheduler.stop() } + + // Select a waiting stage to pre-start + private def getPreStartableStage(stage: Stage): Option[Stage] = { + for (waitingStage <- waitingStages) { + val missingParents = getMissingParentStages(waitingStage) + if (missingParents.contains(stage) && + missingParents.forall( + parent => !(waitingStages.contains(parent) || failedStages.contains(parent)))) { + return Some(waitingStage) + } + } + None + } + + // Check if the given stageId is a pre-started stage + private[scheduler] def handleCheckIfPreStarted(stageId: Int): Boolean = { + if (stageIdToStage.contains(stageId)) { + val stage = stageIdToStage(stageId) + for (preStartedStages <- dependantStagePreStarted.values) { + if (preStartedStages.contains(stage)) { + return true + } + } + } + false + } + + def isPreStartStage(stageId: Int): Boolean = { + if (!removeStageBarrier) { + return false + } + try { + val timeout = AkkaUtils.askTimeout(sc.conf) + val future = eventProcessActor.ask(CheckIfPreStarted(stageId))(timeout) + Await.result(future, timeout).asInstanceOf[Boolean] + } catch { + case e: Exception => + throw new SparkException("Time out asking event processor.", e) + } + } + + // Mark some stages as failed and resubmit them + private def failStages(stages: Array[Stage]) { + // Let's first kill all the running tasks in the failed stage + for (failedStage <- stages) { + taskScheduler.killTasks(failedStage.id, false) + } + if (failedStages.isEmpty && eventProcessActor != null) { + // Don't schedule an event to resubmit failed stages if failed isn't empty, because + // in that case the event will already have been scheduled. eventProcessActor may be + // null during unit tests. 
+ import env.actorSystem.dispatcher + env.actorSystem.scheduler.scheduleOnce( + RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages) + } + failedStages ++= stages + } + + private def maybePreStartWaitingStage(stage: Stage) { + if (removeStageBarrier && taskScheduler.isInstanceOf[TaskSchedulerImpl]) { + // TODO: need a better way to check if there's free slots + val backend = taskScheduler.asInstanceOf[TaskSchedulerImpl].backend + val numPendingTask = pendingTasks.values.map(_.size).sum + val numWaitingStage = waitingStages.size + if (backend.freeSlotAvail(numPendingTask) && numWaitingStage > 0 && + stage.shuffleDep.isDefined) { + for (preStartStage <- getPreStartableStage(stage)) { + logInfo("Pre-start stage " + preStartStage.id) + // Register map output finished so far + mapOutputTracker.registerMapOutputs(stage.shuffleDep.get.shuffleId, + stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray, + changeEpoch = false, isPartial = true) + waitingStages -= preStartStage + runningStages += preStartStage + // Inform parent stages that the dependant stage has been pre-started + for (parentStage <- getMissingParentStages(preStartStage) + if runningStages.contains(parentStage)) { + dependantStagePreStarted.getOrElseUpdate( + parentStage, new ArrayBuffer[Stage]()) += preStartStage + } + submitMissingTasks(preStartStage, activeJobForStage(preStartStage).get) + } + } + } + } } private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler) @@ -1227,6 +1333,9 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule case ResubmitFailedStages => dagScheduler.resubmitFailedStages() + + case CheckIfPreStarted(stageId) => + sender ! dagScheduler.handleCheckIfPreStarted(stageId) } override def postStop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 2b6f7e4205c3..0115a4221d45 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -77,3 +77,5 @@ private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + +private[scheduler] case class CheckIfPreStarted(stageId: Int) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 174b73221afc..3ccb1597b1dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -102,7 +102,8 @@ private[spark] class Pool( for (schedulable <- sortedSchedulableQueue) { sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue } - sortedTaskSetQueue + val partitionedTaskSets = sortedTaskSetQueue.partition(!_.isPreStart()) + partitionedTaskSets._1 ++ partitionedTaskSets._2 } def increaseRunningTasks(taskNum: Int) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index 6a6d8e609bc3..449f88174578 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -30,4 +30,6 @@ private[spark] trait SchedulerBackend { def killTask(taskId: Long, executorId: String, interruptThread: Boolean): 
Unit = throw new UnsupportedOperationException + + def freeSlotAvail(numPendingTask: Int): Boolean = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index df59f444b7a0..f4174a354284 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -105,4 +105,24 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def stop() { getTaskResultExecutor.shutdownNow() } + + def enqueueFailedTaskSync(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, + serializedData: ByteBuffer) { + var reason : TaskEndReason = UnknownReason + try { + if (serializedData != null && serializedData.limit() > 0) { + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, Utils.getSparkClassLoader) + } + } catch { + case cnd: ClassNotFoundException => + // Log an error but keep going here -- the task failed, so not catastrophic if we can't + // deserialize the reason. + val loader = Utils.getContextOrSparkClassLoader + logError( + "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) + case ex: Exception => {} + } + scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 819c35257b5a..b8a2da4ef0e4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -54,4 +54,8 @@ private[spark] trait TaskScheduler { // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int + + // Kill the running tasks in a stage, without aborting the job + // This is used for resubmitting a failed stage + def killTasks(stageId: Int, interruptThread: Boolean): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 5ed2803d76af..74fe9531fadc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -206,65 +206,67 @@ private[spark] class TaskSchedulerImpl( * sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so * that tasks are balanced across the cluster. 
*/ - def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - - // Mark each slave as alive and remember its hostname - // Also track if new executor is added - var newExecAvail = false - for (o <- offers) { - executorIdToHost(o.executorId) = o.host - if (!executorsByHost.contains(o.host)) { - executorsByHost(o.host) = new HashSet[String]() - executorAdded(o.executorId, o.host) - newExecAvail = true + def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = { + val sortedTaskSets = rootPool.getSortedTaskSetQueue + this.synchronized { + SparkEnv.set(sc.env) + + // Mark each slave as alive and remember its hostname + // Also track if new executor is added + var newExecAvail = false + for (o <- offers) { + executorIdToHost(o.executorId) = o.host + if (!executorsByHost.contains(o.host)) { + executorsByHost(o.host) = new HashSet[String]() + executorAdded(o.executorId, o.host) + newExecAvail = true + } } - } - // Randomly shuffle offers to avoid always placing tasks on the same set of workers. - val shuffledOffers = Random.shuffle(offers) - // Build a list of tasks to assign to each worker. - val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) - val availableCpus = shuffledOffers.map(o => o.cores).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue - for (taskSet <- sortedTaskSets) { - logDebug("parentName: %s, name: %s, runningTasks: %s".format( - taskSet.parent.name, taskSet.name, taskSet.runningTasks)) - if (newExecAvail) { - taskSet.executorAdded() + // Randomly shuffle offers to avoid always placing tasks on the same set of workers. + val shuffledOffers = Random.shuffle(offers) + // Build a list of tasks to assign to each worker. + val tasks = shuffledOffers.map(o => new ArrayBuffer[TaskDescription](o.cores)) + val availableCpus = shuffledOffers.map(o => o.cores).toArray + for (taskSet <- sortedTaskSets) { + logDebug("parentName: %s, name: %s, runningTasks: %s".format( + taskSet.parent.name, taskSet.name, taskSet.runningTasks)) + if (newExecAvail) { + taskSet.executorAdded() + } } - } - // Take each TaskSet in our scheduling order, and then offer it each node in increasing order - // of locality levels so that it gets a chance to launch local tasks on all of them. - var launchedTask = false - for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { - do { - launchedTask = false - for (i <- 0 until shuffledOffers.size) { - val execId = shuffledOffers(i).executorId - val host = shuffledOffers(i).host - if (availableCpus(i) >= CPUS_PER_TASK) { - for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskIdToExecutorId(tid) = execId - activeExecutorIds += execId - executorsByHost(host) += execId - availableCpus(i) -= CPUS_PER_TASK - assert (availableCpus(i) >= 0) - launchedTask = true + // Take each TaskSet in our scheduling order, and then offer it each node in increasing order + // of locality levels so that it gets a chance to launch local tasks on all of them. 
+ var launchedTask = false + for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) { + do { + launchedTask = false + for (i <- 0 until shuffledOffers.size) { + val execId = shuffledOffers(i).executorId + val host = shuffledOffers(i).host + if (availableCpus(i) >= CPUS_PER_TASK) { + for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetId(tid) = taskSet.taskSet.id + taskIdToExecutorId(tid) = execId + activeExecutorIds += execId + executorsByHost(host) += execId + availableCpus(i) -= CPUS_PER_TASK + assert (availableCpus(i) >= 0) + launchedTask = true + } } } - } - } while (launchedTask) - } + } while (launchedTask) + } - if (tasks.size > 0) { - hasLaunchedTask = true + if (tasks.size > 0) { + hasLaunchedTask = true + } + return tasks } - return tasks } def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) { @@ -291,7 +293,7 @@ private[spark] class TaskSchedulerImpl( taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData) } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) { taskSet.removeRunningTask(tid) - taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData) + taskResultGetter.enqueueFailedTaskSync(taskSet, tid, state, serializedData) } } case None => @@ -328,11 +330,12 @@ private[spark] class TaskSchedulerImpl( taskState: TaskState, reason: TaskEndReason) = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { + //if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. 
- backend.reviveOffers() - } + //backend.reviveOffers() + //} + backend.reviveOffers() } def error(message: String) { @@ -437,6 +440,17 @@ private[spark] class TaskSchedulerImpl( // By default, rack is unknown def getRackForHost(value: String): Option[String] = None + + override def killTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { + logInfo("Killing tasks in stage " + stageId) + activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) => + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId, interruptThread) + } + tsm.kill() + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 059cc9085a2e..a95ca26ac811 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -580,6 +580,7 @@ private[spark] class TaskSetManager( case TaskResultLost => failureReason = "Lost result for TID %s on host %s".format(tid, info.host) logWarning(failureReason) + // TODO: may cause some sort of "deadlock" if we lost the result of a shuffle map task case _ => failureReason = "TID %s on host %s failed for unknown reason".format(tid, info.host) @@ -636,9 +637,7 @@ private[spark] class TaskSetManager( override def removeSchedulable(schedulable: Schedulable) {} override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = { - var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]() - sortedTaskSetQueue += this - sortedTaskSetQueue + ArrayBuffer[TaskSetManager](this) } /** Called by TaskScheduler when an executor is lost so we can re-enqueue our tasks */ @@ -756,6 +755,16 @@ private[spark] class TaskSetManager( levels.toArray } + // Test if this stage is in pre-start state + def isPreStart() = sched.dagScheduler.isPreStartStage(stageId) + + // Kill this task set manager + def kill() { + isZombie = true + runningTasksSet.clear() + maybeFinishTaskSet() + } + // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available def executorAdded() { def newLocAvail(index: Int): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 05d01b0c821f..d820de763741 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -87,16 +87,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } case StatusUpdate(executorId, taskId, state, data) => - scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { if (executorActor.contains(executorId)) { freeCores(executorId) += scheduler.CPUS_PER_TASK - makeOffers(executorId) + // Make offer right away if the task succeeds, otherwise wait for the scheduler to + // revive offers + // TODO: there may still be deadlock if a task reports success + // but we failed retrieving its result + if (state == TaskState.FINISHED) { + makeOffers(executorId) + } } else { // Ignoring the update since we don't know about the executor. 
val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s" logWarning(msg.format(taskId, state, sender, executorId)) } + scheduler.statusUpdate(taskId, state, data.value) } case ReviveOffers => @@ -247,6 +253,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A throw new SparkException("Error notifying standalone scheduler's driver actor", e) } } + + override def freeSlotAvail(numPendingTask: Int): Boolean = { + numPendingTask * scheduler.CPUS_PER_TASK < totalCoreCount.get() + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index a932455776e3..4b16746705b3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{MetadataFetchFailedException, FetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util.CompletionIterator @@ -37,20 +37,29 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) val blockManager = SparkEnv.get.blockManager + val mapOutputTracker = SparkEnv.get.mapOutputTracker val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) + val blockFetcherItr = if (statuses.exists(_._1 == null)) { + if (!blockManager.conf.getBoolean("spark.scheduler.removeStageBarrier", false)) { + throw new MetadataFetchFailedException( + shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) + } + blockManager.getPartial(statuses, mapOutputTracker, serializer, shuffleId, reduceId) + } else { + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] + for (((address, size), index) <- statuses.zipWithIndex) { + splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) + } + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { + case (address, splits) => + (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) + } + blockManager.getMultiple(blocksByAddress, serializer) } def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { @@ -73,7 +82,6 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = 
blockManager.getMultiple(blocksByAddress, serializer) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 408a79708805..4e92f5d94c24 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -19,13 +19,11 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue} import io.netty.buffer.ByteBuf -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.{MapOutputTracker, Logging, SparkException} import org.apache.spark.network.BufferMessage import org.apache.spark.network.ConnectionManagerId import org.apache.spark.network.netty.ShuffleCopier @@ -344,4 +342,96 @@ object BlockFetcherIterator { } } // End of NettyBlockFetcherIterator + + class PartialBlockFetcherIterator( + private val blockManager: BlockManager, + private var statuses: Array[(BlockManagerId, Long)], + private val mapOutputTracker: MapOutputTracker, + serializer: Serializer, + shuffleId: Int, + reduceId: Int) + extends BlockFetcherIterator { + private val iterators=new ArrayBuffer[BlockFetcherIterator]() + + // Track the map outputs we've delegated + private val delegatedStatuses = new HashSet[Int]() + + private var localEpoch = 0 + + // Check if the map output is partial + private def isPartial = statuses.exists(_._1 == null) + + // Get the updated map output + private def updateStatuses() { + val update = mapOutputTracker.getUpdatedStatus(shuffleId, reduceId, localEpoch) + statuses = update._1 + localEpoch = update._2 + } + + private def readyStatuses = (0 until statuses.size).filter(statuses(_)._1 != null) + + // Check if there's new map outputs available + private def newStatusesReady = readyStatuses.exists(!delegatedStatuses.contains(_)) + + private def getIterator() = { + while (!newStatusesReady) { + logInfo("Still missing " + statuses.filter(_._1 == null).size + + " map outputs for reduce " + reduceId + " of shuffle " + shuffleId) + updateStatuses() + } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] + for (index <- readyStatuses if !delegatedStatuses.contains(index)) { + splitsByAddress.getOrElseUpdate(statuses(index)._1, ArrayBuffer()) += ((index, statuses(index)._2)) + delegatedStatuses += index + } + val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { + case (address, splits) => + (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) + } + logInfo("Delegating " + blocksByAddress.map(_._2.size).sum + + " blocks to a new iterator for reduce " + reduceId + " of shuffle " + shuffleId) + blockManager.getMultiple(blocksByAddress, serializer) + } + + override def initialize(){ + iterators += getIterator() + } + + override def hasNext: Boolean = { + // Firstly see if the delegated iterators have more blocks for us + if (iterators.exists(_.hasNext)) { + return true + } + // If we have blocks not delegated yet, try to delegate them to a new iterator + // and depend on the iterator to tell us if there are valid blocks. 
+ while (delegatedStatuses.size < statuses.size) { + val newItr = getIterator() + iterators += newItr + if (newItr.hasNext) { + return true + } + } + false + } + + override def next(): (BlockId, Option[Iterator[Any]]) = { + // Try to get a block from the iterators we've created + for (itr <- iterators if itr.hasNext) { + return itr.next() + } + // We rely on the iterators for "hasNext", shouldn't get here + throw new SparkException("No more blocks to fetch for reduceId " + reduceId) + } + + override def totalBlocks = iterators.map(_.totalBlocks).sum + + override def numLocalBlocks = iterators.map(_.numLocalBlocks).sum + + override def numRemoteBlocks = iterators.map(_.numRemoteBlocks).sum + + override def fetchWaitTime = iterators.map(_.fetchWaitTime).sum + + override def remoteBytesRead = iterators.map(_.remoteBytesRead).sum + } + // End of PartialBlockFetcherIterator } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0db0a5bc7341..2be3bdd945b3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -561,6 +561,18 @@ private[spark] class BlockManager( iter } + def getPartial( + statuses: Array[(BlockManagerId, Long)], + mapOutputTracker: MapOutputTracker, + serializer: Serializer, + shuffleId: Int, + reduceId: Int) = { + val iter = new BlockFetcherIterator.PartialBlockFetcherIterator( + this, statuses, mapOutputTracker, serializer, shuffleId, reduceId) + iter.initialize() + iter + } + def put( blockId: BlockId, values: Iterator[Any], diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 71f66c826c5b..f4f8a0b21f7e 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -153,7 +153,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long) * size must also be passed by the caller. * * Lock on the object putLock to ensure that all the put requests and its associated block - * dropping is done by only on thread at a time. Otherwise while one thread is dropping + * dropping is done by only one thread at a time. Otherwise while one thread is dropping * blocks to free memory for one block, another thread may use up the freed space for * another block. 
* diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 35910e552fe8..848c4f896c0e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -123,7 +123,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { override def releaseWriters(success: Boolean) { if (consolidateShuffleFiles) { if (success) { - val offsets = writers.map(_.fileSegment().offset) + val offsets = writers.map(writer => writer.fileSegment().offset + writer.bytesWritten) fileGroup.recordMapOutput(mapId, offsets) } recycleFileGroup(fileGroup) @@ -143,7 +143,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { val filename = physicalFileName(shuffleId, bucketId, fileId) blockManager.diskBlockManager.getFile(filename) } - val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files) + val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) shuffleState.allFileGroups.add(fileGroup) fileGroup } @@ -231,9 +231,13 @@ object ShuffleBlockManager { * This ordering allows us to compute block lengths by examining the following block offset. * Note: mapIdToIndex(mapId) returns the index of the mapper into the vector for every * reducer. + * We also keep the offset of the "one past the end" block, which is effectively the file length. + * Therefore, when appending new offsets, we are actually appending new file lengths. */ private val blockOffsetsByReducer = Array.fill[PrimitiveVector[Long]](files.length) { val offsets = new PrimitiveVector[Long]() + offsets += 0 + offsets } def numBlocks = mapIdToIndex.size @@ -253,13 +257,16 @@ object ShuffleBlockManager { val blockOffsets = blockOffsetsByReducer(reducerId) val index = mapIdToIndex.getOrElse(mapId, -1) if (index >= 0) { + assert(index + 1 < blockOffsets.size, + "Index is " + index + ", total size is " + blockOffsets.size) val offset = blockOffsets(index) - val length = - if (index + 1 < numBlocks) { - blockOffsets(index + 1) - offset - } else { - file.length() - offset - } + val length = blockOffsets(index + 1) - offset +// val length = +// if (index + 1 < numBlocks) { +// blockOffsets(index + 1) - offset +// } else { +// file.length() - offset +// } assert(length >= 0) Some(new FileSegment(file, offset, length)) } else {
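
A minimal driver-side sketch of how the pre-start behaviour above would be switched on once the patch is applied. Only the configuration key spark.scheduler.removeStageBarrier comes from the patch (it is read via env.conf.getBoolean in DAGScheduler and blockManager.conf.getBoolean in BlockStoreShuffleFetcher, and defaults to false); the application name and the sample job are hypothetical, and everything else is the standard SparkConf/SparkContext API.

import org.apache.spark.{SparkConf, SparkContext}

object RemoveStageBarrierDemo {
  def main(args: Array[String]): Unit = {
    // Opt in to pre-starting reduce stages before all map outputs are registered.
    // With the flag left at its default (false), the usual stage barrier is kept.
    val conf = new SparkConf()
      .setAppName("RemoveStageBarrierDemo")
      .set("spark.scheduler.removeStageBarrier", "true")
    val sc = new SparkContext(conf)

    // A simple two-stage job: with the flag on, the reduce stage may be pre-started
    // and fetch partial map outputs while the shuffle map stage is still running.
    val counts = sc.parallelize(1 to 100000, 16)
      .map(i => (i % 100, 1L))
      .reduceByKey(_ + _)
      .collect()

    println(counts.take(5).mkString(", "))
    sc.stop()
  }
}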