Skip to content

Commit 3158fc3

Browse files
bersprocketsJoshRosencloud-fan
committed
[SPARK-23243][SPARK-20715][CORE][2.2] Fix RDD.repartition() data correctness issue
## What changes were proposed in this pull request? Back port of #22354 and #17955 to 2.2 (#22354 depends on methods introduced by #17955). ------- An alternative fix for #21698 When Spark rerun tasks for an RDD, there are 3 different behaviors: 1. determinate. Always return the same result with same order when rerun. 2. unordered. Returns same data set in random order when rerun. 3. indeterminate. Returns different result when rerun. Normally Spark doesn't need to care about it. Spark runs stages one by one, when a task is failed, just rerun it. Although the rerun task may return a different result, users will not be surprised. However, Spark may rerun a finished stage when seeing fetch failures. When this happens, Spark needs to rerun all the tasks of all the succeeding stages if the RDD output is indeterminate, because the input of the succeeding stages has been changed. If the RDD output is determinate, we only need to rerun the failed tasks of the succeeding stages, because the input doesn't change. If the RDD output is unordered, it's same as determinate, because shuffle partitioner is always deterministic(round-robin partitioner is not a shuffle partitioner that extends `org.apache.spark.Partitioner`), so the reducers will still get the same input data set. This PR fixed the failure handling for `repartition`, to avoid correctness issues. For `repartition`, it applies a stateful map function to generate a round-robin id, which is order sensitive and makes the RDD's output indeterminate. When the stage contains `repartition` reruns, we must also rerun all the tasks of all the succeeding stages. **future improvement:** 1. Currently we can't rollback and rerun a shuffle map stage, and just fail. We should fix it later. https://issues.apache.org/jira/browse/SPARK-25341 2. Currently we can't rollback and rerun a result stage, and just fail. We should fix it later. https://issues.apache.org/jira/browse/SPARK-25342 3. We should provide public API to allow users to tag the random level of the RDD's computing function. ## How was this patch tested? a new test case Closes #22382 from bersprockets/SPARK-23243-2.2. Lead-authored-by: Bruce Robbins <[email protected]> Co-authored-by: Josh Rosen <[email protected]> Co-authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent af41ded commit 3158fc3

File tree

13 files changed

+750
-406
lines changed

13 files changed

+750
-406
lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 360 additions & 276 deletions
Large diffs are not rendered by default.

core/src/main/scala/org/apache/spark/Partitioner.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ import org.apache.spark.util.random.SamplingUtils
3232
/**
3333
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
3434
* Maps each key to a partition ID, from 0 to `numPartitions - 1`.
35+
*
36+
* Note that, partitioner must be deterministic, i.e. it must return the same partition id given
37+
* the same partition key.
3538
*/
3639
abstract class Partitioner extends Serializable {
3740
def numPartitions: Int

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,14 @@ private[spark] class Executor(
325325
throw new TaskKilledException(killReason.get)
326326
}
327327

328-
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
329-
env.mapOutputTracker.updateEpoch(task.epoch)
328+
// The purpose of updating the epoch here is to invalidate executor map output status cache
329+
// in case FetchFailures have occurred. In local mode `env.mapOutputTracker` will be
330+
// MapOutputTrackerMaster and its cache invalidation is not based on epoch numbers so
331+
// we don't need to make any special calls here.
332+
if (!isLocal) {
333+
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
334+
env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
335+
}
330336

331337
// Run the actual task and measure its runtime.
332338
taskStart = System.currentTimeMillis()

core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,22 @@ import org.apache.spark.{Partition, TaskContext}
2323

2424
/**
2525
* An RDD that applies the provided function to every partition of the parent RDD.
26+
*
27+
* @param prev the parent RDD.
28+
* @param f The function used to map a tuple of (TaskContext, partition index, input iterator) to
29+
* an output iterator.
30+
* @param preservesPartitioning Whether the input function preserves the partitioner, which should
31+
* be `false` unless `prev` is a pair RDD and the input function
32+
* doesn't modify the keys.
33+
* @param isOrderSensitive whether or not the function is order-sensitive. If it's order
34+
* sensitive, it may return totally different result when the input order
35+
* is changed. Mostly stateful functions are order-sensitive.
2636
*/
2737
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
2838
var prev: RDD[T],
2939
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
30-
preservesPartitioning: Boolean = false)
40+
preservesPartitioning: Boolean = false,
41+
isOrderSensitive: Boolean = false)
3142
extends RDD[U](prev) {
3243

3344
override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
@@ -41,4 +52,12 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
4152
super.clearDependencies()
4253
prev = null
4354
}
55+
56+
override protected def getOutputDeterministicLevel = {
57+
if (isOrderSensitive && prev.outputDeterministicLevel == DeterministicLevel.UNORDERED) {
58+
DeterministicLevel.INDETERMINATE
59+
} else {
60+
super.getOutputDeterministicLevel
61+
}
62+
}
4463
}

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,8 +461,9 @@ abstract class RDD[T: ClassTag](
461461

462462
// include a shuffle step so that our upstream tasks are still distributed
463463
new CoalescedRDD(
464-
new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
465-
new HashPartitioner(numPartitions)),
464+
new ShuffledRDD[Int, T, T](
465+
mapPartitionsWithIndexInternal(distributePartition, isOrderSensitive = true),
466+
new HashPartitioner(numPartitions)),
466467
numPartitions,
467468
partitionCoalescer).values
468469
} else {
@@ -806,16 +807,21 @@ abstract class RDD[T: ClassTag](
806807
* serializable and don't require closure cleaning.
807808
*
808809
* @param preservesPartitioning indicates whether the input function preserves the partitioner,
809-
* which should be `false` unless this is a pair RDD and the input function doesn't modify
810-
* the keys.
810+
* which should be `false` unless this is a pair RDD and the input
811+
* function doesn't modify the keys.
812+
* @param isOrderSensitive whether or not the function is order-sensitive. If it's order
813+
* sensitive, it may return totally different result when the input order
814+
* is changed. Mostly stateful functions are order-sensitive.
811815
*/
812816
private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
813817
f: (Int, Iterator[T]) => Iterator[U],
814-
preservesPartitioning: Boolean = false): RDD[U] = withScope {
818+
preservesPartitioning: Boolean = false,
819+
isOrderSensitive: Boolean = false): RDD[U] = withScope {
815820
new MapPartitionsRDD(
816821
this,
817822
(context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
818-
preservesPartitioning)
823+
preservesPartitioning = preservesPartitioning,
824+
isOrderSensitive = isOrderSensitive)
819825
}
820826

821827
/**
@@ -1634,6 +1640,16 @@ abstract class RDD[T: ClassTag](
16341640
}
16351641
}
16361642

1643+
/**
1644+
* Return whether this RDD is reliably checkpointed and materialized.
1645+
*/
1646+
private[rdd] def isReliablyCheckpointed: Boolean = {
1647+
checkpointData match {
1648+
case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
1649+
case _ => false
1650+
}
1651+
}
1652+
16371653
/**
16381654
* Gets the name of the directory to which this RDD was checkpointed.
16391655
* This is not defined if the RDD is checkpointed locally.
@@ -1838,6 +1854,63 @@ abstract class RDD[T: ClassTag](
18381854
def toJavaRDD() : JavaRDD[T] = {
18391855
new JavaRDD(this)(elementClassTag)
18401856
}
1857+
1858+
/**
1859+
* Returns the deterministic level of this RDD's output. Please refer to [[DeterministicLevel]]
1860+
* for the definition.
1861+
*
1862+
* By default, an reliably checkpointed RDD, or RDD without parents(root RDD) is DETERMINATE. For
1863+
* RDDs with parents, we will generate a deterministic level candidate per parent according to
1864+
* the dependency. The deterministic level of the current RDD is the deterministic level
1865+
* candidate that is deterministic least. Please override [[getOutputDeterministicLevel]] to
1866+
* provide custom logic of calculating output deterministic level.
1867+
*/
1868+
// TODO: make it public so users can set deterministic level to their custom RDDs.
1869+
// TODO: this can be per-partition. e.g. UnionRDD can have different deterministic level for
1870+
// different partitions.
1871+
private[spark] final lazy val outputDeterministicLevel: DeterministicLevel.Value = {
1872+
if (isReliablyCheckpointed) {
1873+
DeterministicLevel.DETERMINATE
1874+
} else {
1875+
getOutputDeterministicLevel
1876+
}
1877+
}
1878+
1879+
@DeveloperApi
1880+
protected def getOutputDeterministicLevel: DeterministicLevel.Value = {
1881+
val deterministicLevelCandidates = dependencies.map {
1882+
// The shuffle is not really happening, treat it like narrow dependency and assume the output
1883+
// deterministic level of current RDD is same as parent.
1884+
case dep: ShuffleDependency[_, _, _] if dep.rdd.partitioner.exists(_ == dep.partitioner) =>
1885+
dep.rdd.outputDeterministicLevel
1886+
1887+
case dep: ShuffleDependency[_, _, _] =>
1888+
if (dep.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
1889+
// If map output was indeterminate, shuffle output will be indeterminate as well
1890+
DeterministicLevel.INDETERMINATE
1891+
} else if (dep.keyOrdering.isDefined && dep.aggregator.isDefined) {
1892+
// if aggregator specified (and so unique keys) and key ordering specified - then
1893+
// consistent ordering.
1894+
DeterministicLevel.DETERMINATE
1895+
} else {
1896+
// In Spark, the reducer fetches multiple remote shuffle blocks at the same time, and
1897+
// the arrival order of these shuffle blocks are totally random. Even if the parent map
1898+
// RDD is DETERMINATE, the reduce RDD is always UNORDERED.
1899+
DeterministicLevel.UNORDERED
1900+
}
1901+
1902+
// For narrow dependency, assume the output deterministic level of current RDD is same as
1903+
// parent.
1904+
case dep => dep.rdd.outputDeterministicLevel
1905+
}
1906+
1907+
if (deterministicLevelCandidates.isEmpty) {
1908+
// By default we assume the root RDD is determinate.
1909+
DeterministicLevel.DETERMINATE
1910+
} else {
1911+
deterministicLevelCandidates.maxBy(_.id)
1912+
}
1913+
}
18411914
}
18421915

18431916

@@ -1891,3 +1964,18 @@ object RDD {
18911964
new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
18921965
}
18931966
}
1967+
1968+
/**
1969+
* The deterministic level of RDD's output (i.e. what `RDD#compute` returns). This explains how
1970+
* the output will diff when Spark reruns the tasks for the RDD. There are 3 deterministic levels:
1971+
* 1. DETERMINATE: The RDD output is always the same data set in the same order after a rerun.
1972+
* 2. UNORDERED: The RDD output is always the same data set but the order can be different
1973+
* after a rerun.
1974+
* 3. INDETERMINATE. The RDD output can be different after a rerun.
1975+
*
1976+
* Note that, the output of an RDD usually relies on the parent RDDs. When the parent RDD's output
1977+
* is INDETERMINATE, it's very likely the RDD's output is also INDETERMINATE.
1978+
*/
1979+
private[spark] object DeterministicLevel extends Enumeration {
1980+
val DETERMINATE, UNORDERED, INDETERMINATE = Value
1981+
}

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
3838
import org.apache.spark.internal.Logging
3939
import org.apache.spark.network.util.JavaUtils
4040
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
41-
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
41+
import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData}
4242
import org.apache.spark.rpc.RpcTimeout
4343
import org.apache.spark.storage._
4444
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
@@ -328,25 +328,14 @@ class DAGScheduler(
328328
val numTasks = rdd.partitions.length
329329
val parents = getOrCreateParentStages(rdd, jobId)
330330
val id = nextStageId.getAndIncrement()
331-
val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
331+
val stage = new ShuffleMapStage(
332+
id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
332333

333334
stageIdToStage(id) = stage
334335
shuffleIdToMapStage(shuffleDep.shuffleId) = stage
335336
updateJobIdStageIdMaps(jobId, stage)
336337

337-
if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
338-
// A previously run stage generated partitions for this shuffle, so for each output
339-
// that's still available, copy information about that output location to the new stage
340-
// (so we don't unnecessarily re-compute that data).
341-
val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
342-
val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
343-
(0 until locs.length).foreach { i =>
344-
if (locs(i) ne null) {
345-
// locs(i) will be null if missing
346-
stage.addOutputLoc(i, locs(i))
347-
}
348-
}
349-
} else {
338+
if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
350339
// Kind of ugly: need to register RDDs with the cache and map output tracker here
351340
// since we can't do it in the RDD constructor because # of partitions is unknown
352341
logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1240,7 +1229,8 @@ class DAGScheduler(
12401229
// The epoch of the task is acceptable (i.e., the task was launched after the most
12411230
// recent failure we're aware of for the executor), so mark the task's output as
12421231
// available.
1243-
shuffleStage.addOutputLoc(smt.partitionId, status)
1232+
mapOutputTracker.registerMapOutput(
1233+
shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
12441234
// Remove the task's partition from pending partitions. This may have already been
12451235
// done above, but will not have been done yet in cases where the task attempt was
12461236
// from an earlier attempt of the stage (i.e., not the attempt that's currently
@@ -1257,16 +1247,14 @@ class DAGScheduler(
12571247
logInfo("waiting: " + waitingStages)
12581248
logInfo("failed: " + failedStages)
12591249

1260-
// We supply true to increment the epoch number here in case this is a
1261-
// recomputation of the map outputs. In that case, some nodes may have cached
1262-
// locations with holes (from when we detected the error) and will need the
1263-
// epoch incremented to refetch them.
1264-
// TODO: Only increment the epoch number if this is not the first time
1265-
// we registered these map outputs.
1266-
mapOutputTracker.registerMapOutputs(
1267-
shuffleStage.shuffleDep.shuffleId,
1268-
shuffleStage.outputLocInMapOutputTrackerFormat(),
1269-
changeEpoch = true)
1250+
// This call to increment the epoch may not be strictly necessary, but it is retained
1251+
// for now in order to minimize the changes in behavior from an earlier version of the
1252+
// code. This existing behavior of always incrementing the epoch following any
1253+
// successful shuffle map stage completion may have benefits by causing unneeded
1254+
// cached map outputs to be cleaned up earlier on executors. In the future we can
1255+
// consider removing this call, but this will require some extra investigation.
1256+
// See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
1257+
mapOutputTracker.incrementEpoch()
12701258

12711259
clearCacheLocs()
12721260

@@ -1344,6 +1332,63 @@ class DAGScheduler(
13441332
failedStages += failedStage
13451333
failedStages += mapStage
13461334
if (noResubmitEnqueued) {
1335+
// If the map stage is INDETERMINATE, which means the map tasks may return
1336+
// different result when re-try, we need to re-try all the tasks of the failed
1337+
// stage and its succeeding stages, because the input data will be changed after the
1338+
// map tasks are re-tried.
1339+
// Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
1340+
// guaranteed to be determinate, so the input data of the reducers will not change
1341+
// even if the map tasks are re-tried.
1342+
if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
1343+
// It's a little tricky to find all the succeeding stages of `failedStage`, because
1344+
// each stage only know its parents not children. Here we traverse the stages from
1345+
// the leaf nodes (the result stages of active jobs), and rollback all the stages
1346+
// in the stage chains that connect to the `failedStage`. To speed up the stage
1347+
// traversing, we collect the stages to rollback first. If a stage needs to
1348+
// rollback, all its succeeding stages need to rollback to.
1349+
val stagesToRollback = scala.collection.mutable.HashSet(failedStage)
1350+
1351+
def collectStagesToRollback(stageChain: List[Stage]): Unit = {
1352+
if (stagesToRollback.contains(stageChain.head)) {
1353+
stageChain.drop(1).foreach(s => stagesToRollback += s)
1354+
} else {
1355+
stageChain.head.parents.foreach { s =>
1356+
collectStagesToRollback(s :: stageChain)
1357+
}
1358+
}
1359+
}
1360+
1361+
def generateErrorMessage(stage: Stage): String = {
1362+
"A shuffle map stage with indeterminate output was failed and retried. " +
1363+
s"However, Spark cannot rollback the $stage to re-process the input data, " +
1364+
"and has to fail this job. Please eliminate the indeterminacy by " +
1365+
"checkpointing the RDD before repartition and try again."
1366+
}
1367+
1368+
activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
1369+
1370+
stagesToRollback.foreach {
1371+
case mapStage: ShuffleMapStage =>
1372+
val numMissingPartitions = mapStage.findMissingPartitions().length
1373+
if (numMissingPartitions < mapStage.numTasks) {
1374+
// TODO: support to rollback shuffle files.
1375+
// Currently the shuffle writing is "first write wins", so we can't re-run a
1376+
// shuffle map stage and overwrite existing shuffle files. We have to finish
1377+
// SPARK-8029 first.
1378+
abortStage(mapStage, generateErrorMessage(mapStage), None)
1379+
}
1380+
1381+
case resultStage: ResultStage if resultStage.activeJob.isDefined =>
1382+
val numMissingPartitions = resultStage.findMissingPartitions().length
1383+
if (numMissingPartitions < resultStage.numTasks) {
1384+
// TODO: support to rollback result tasks.
1385+
abortStage(resultStage, generateErrorMessage(resultStage), None)
1386+
}
1387+
1388+
case _ =>
1389+
}
1390+
}
1391+
13471392
// We expect one executor failure to trigger many FetchFailures in rapid succession,
13481393
// but all of those task failures can typically be handled by a single resubmission of
13491394
// the failed stage. We avoid flooding the scheduler's event queue with resubmit
@@ -1367,7 +1412,6 @@ class DAGScheduler(
13671412
}
13681413
// Mark the map whose fetch failed as broken in the map stage
13691414
if (mapId != -1) {
1370-
mapStage.removeOutputLoc(mapId, bmAddress)
13711415
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
13721416
}
13731417

@@ -1416,17 +1460,7 @@ class DAGScheduler(
14161460

14171461
if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
14181462
logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
1419-
// TODO: This will be really slow if we keep accumulating shuffle map stages
1420-
for ((shuffleId, stage) <- shuffleIdToMapStage) {
1421-
stage.removeOutputsOnExecutor(execId)
1422-
mapOutputTracker.registerMapOutputs(
1423-
shuffleId,
1424-
stage.outputLocInMapOutputTrackerFormat(),
1425-
changeEpoch = true)
1426-
}
1427-
if (shuffleIdToMapStage.isEmpty) {
1428-
mapOutputTracker.incrementEpoch()
1429-
}
1463+
mapOutputTracker.removeOutputsOnExecutor(execId)
14301464
clearCacheLocs()
14311465
}
14321466
} else {

0 commit comments

Comments
 (0)