Commit 48f1ef4

update
1 parent 0733bfb commit 48f1ef4

2 files changed (+105, −34 lines)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (23 additions, 12 deletions)

@@ -341,14 +341,18 @@ class DAGScheduler(
   }
 
   /**
-   * Check to make sure we are not launching a barrier stage that contains PartitionPruningRDD,
-   * which may launch tasks on partial partitions.
+   * Check to make sure we don't launch a barrier stage with unsupported RDD chain pattern. The
+   * following patterns are not supported:
+   * 1. Ancestor RDDs that have different number of partitions from the resulting RDD (eg.
+   *    union()/coalesce()/first()/PartitionPruningRDD);
+   * 2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2)).
    */
-  private def checkBarrierStageWithPartitionPruningRDD(rdd: RDD[_]): Unit = {
-    if (rdd.isBarrier() &&
-        !traverseParentRDDsWithinStage(rdd, (r => !r.isInstanceOf[PartitionPruningRDD[_]]))) {
-      throw new SparkException("Don't support run a barrier stage that contains " +
-        "PartitionPruningRDD, because PartitionPruningRDD may launch tasks on partial partitions.")
+  private def checkBarrierStageWithRDDChainPattern(rdd: RDD[_], numPartitions: Int): Unit = {
+    val predicate: RDD[_] => Boolean = (r =>
+      r.getNumPartitions == numPartitions && r.dependencies.filter(_.rdd.isBarrier()).size <= 1)
+    if (rdd.isBarrier() && !traverseParentRDDsWithinStage(rdd, predicate)) {
+      throw new SparkException(
+        DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
     }
   }
 
@@ -360,7 +364,7 @@ class DAGScheduler(
    */
   def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
     val rdd = shuffleDep.rdd
-    checkBarrierStageWithPartitionPruningRDD(rdd)
+    checkBarrierStageWithRDDChainPattern(rdd, rdd.getNumPartitions)
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
@@ -389,7 +393,7 @@
       partitions: Array[Int],
       jobId: Int,
       callSite: CallSite): ResultStage = {
-    checkBarrierStageWithPartitionPruningRDD(rdd)
+    checkBarrierStageWithRDDChainPattern(rdd, partitions.toSet.size)
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
     val stage = new ResultStage(id, rdd, func, partitions, parents, jobId, callSite)
@@ -466,8 +470,8 @@
   }
 
   /**
-   * Traverse all the parent RDDs within the same stage with the given RDD, check whether all the
-   * parent RDDs satisfy a given predicate.
+   * Traverses the given RDD and its ancestors within the same stage and checks whether all of the
+   * RDDs satisfy a given predicate.
    */
   private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = {
     val visited = new HashSet[RDD[_]]
@@ -481,7 +485,7 @@
       }
       visited += toVisit
       toVisit.dependencies.foreach {
-        case shuffleDep: ShuffleDependency[_, _, _] =>
+        case _: ShuffleDependency[_, _, _] =>
           // Not within the same stage with current rdd, do nothing.
         case dependency =>
           waitingForVisit.push(dependency.rdd)
@@ -1986,4 +1990,11 @@ private[spark] object DAGScheduler {
 
   // Number of consecutive stage attempts allowed before a stage is aborted
   val DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS = 4
+
+  // Error message when running a barrier stage that have unsupported RDD chain pattern.
+  val ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN =
+    "[SPARK-24820][SPARK-24821]: Barrier execution mode does not allow the following pattern of " +
+    "RDD chain within a barrier stage:\n1. Ancestor RDDs that have different number of " +
+    "partitions from the resulting RDD (eg. union()/coalesce()/first()/PartitionPruningRDD);\n" +
+    "2. An RDD that depends on multiple barrier RDDs (eg. barrierRdd1.zip(barrierRdd2))."
 }
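To make the new check concrete, here is a minimal, self-contained sketch of the traversal and predicate above. This is not Spark code: Node and Dep are hypothetical stand-ins for RDD and Dependency, and the barrier flag is set by hand rather than propagated the way RDD.isBarrier() propagates it in Spark.

object BarrierChainCheckSketch {
  // Hypothetical stand-ins: a Node is an "RDD" with a partition count, a barrier
  // flag, and dependencies; isShuffle marks a stage boundary.
  final case class Dep(rdd: Node, isShuffle: Boolean)
  final class Node(val numPartitions: Int, val isBarrier: Boolean, val deps: Seq[Dep])

  // Mirrors traverseParentRDDsWithinStage: visit the RDD and every ancestor reachable
  // without crossing a shuffle dependency, and require the predicate to hold for each.
  def withinStageAllSatisfy(root: Node, predicate: Node => Boolean): Boolean = {
    val visited = scala.collection.mutable.HashSet.empty[Node]
    def visit(n: Node): Boolean =
      visited.contains(n) || {
        visited += n
        predicate(n) && n.deps.forall(d => d.isShuffle || visit(d.rdd))
      }
    visit(root)
  }

  def main(args: Array[String]): Unit = {
    // Models barrier().mapPartitions(...) over 4 partitions followed by coalesce(1):
    // the ancestor keeps 4 partitions while the resulting RDD has only 1.
    val barrier = new Node(4, isBarrier = true, Nil)
    val coalesced = new Node(1, isBarrier = true, Seq(Dep(barrier, isShuffle = false)))
    // Same shape as the predicate built in checkBarrierStageWithRDDChainPattern.
    val ok = withinStageAllSatisfy(coalesced, r =>
      r.numPartitions == coalesced.numPartitions && r.deps.count(_.rdd.isBarrier) <= 1)
    println(ok) // false -> DAGScheduler would fail the job at submission
  }
}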

core/src/test/scala/org/apache/spark/BarrierStageOnSubmittedSuite.scala (82 additions, 22 deletions)

@@ -20,74 +20,134 @@ package org.apache.spark
 import scala.concurrent.duration._
 import scala.language.postfixOps
 
+import org.scalatest.BeforeAndAfterEach
+
 import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
+import org.apache.spark.scheduler.DAGScheduler
 import org.apache.spark.util.ThreadUtils
 
 /**
  * This test suite covers all the cases that shall fail fast on job submitted that contains one
  * of more barrier stages.
  */
-class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext {
+class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach
+  with LocalSparkContext {
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+
+    val conf = new SparkConf()
+      .setMaster("local[4]")
+      .setAppName("test")
+    sc = new SparkContext(conf)
+  }
 
-  private def testSubmitJob(sc: SparkContext, rdd: RDD[Int], message: String): Unit = {
+  private def testSubmitJob(
+      sc: SparkContext,
+      rdd: RDD[Int],
+      partitions: Option[Seq[Int]] = None,
+      message: String): Unit = {
     val futureAction = sc.submitJob(
       rdd,
       (iter: Iterator[Int]) => iter.toArray,
-      0 until rdd.partitions.length,
+      partitions.getOrElse(0 until rdd.partitions.length),
       { case (_, _) => return }: (Int, Array[Int]) => Unit,
       { return }
     )
 
     val error = intercept[SparkException] {
-      ThreadUtils.awaitResult(futureAction, 1 seconds)
+      ThreadUtils.awaitResult(futureAction, 5 seconds)
     }.getCause.getMessage
     assert(error.contains(message))
   }
 
   test("submit a barrier ResultStage that contains PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
       .mapPartitions((iter, context) => iter)
     testSubmitJob(sc, rdd,
-      "Don't support run a barrier stage that contains PartitionPruningRDD")
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }
 
   test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
       .mapPartitions((iter, context) => iter)
       .repartition(2)
       .map(x => x + 1)
     testSubmitJob(sc, rdd,
-      "Don't support run a barrier stage that contains PartitionPruningRDD")
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }
 
   test("submit a barrier stage that doesn't contain PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .repartition(2)
       .barrier()
       .mapPartitions((iter, context) => iter)
-
     // Should be able to submit job and run successfully.
     val result = rdd.collect().sorted
     assert(result === Seq(6, 7, 8, 9, 10))
   }
+
+  test("submit a barrier stage with partial partitions") {
+    val rdd = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    testSubmitJob(sc, rdd, Some(Seq(1, 3)),
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with union()") {
+    val rdd1 = sc.parallelize(1 to 10, 2)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(1 to 20, 2)
+    val rdd3 = rdd1
+      .union(rdd2)
+      .map(x => x * 2)
+    // Fail the job on submit because the barrier RDD (rdd1) may be not assigned Task 0.
+    testSubmitJob(sc, rdd3,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with coalesce()") {
+    val rdd = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+      .coalesce(1)
+    // Fail the job on submit because the barrier RDD requires to run on 4 tasks, but the stage
+    // only launches 1 task.
+    testSubmitJob(sc, rdd,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") {
+    val rdd1 = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(11 to 20, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd3 = rdd1
+      .zip(rdd2)
+      .map(x => x._1 + x._2)
+    testSubmitJob(sc, rdd3,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with zip()") {
+    val rdd1 = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(11 to 20, 4)
+    val rdd3 = rdd1
+      .zip(rdd2)
+      .map(x => x._1 + x._2)
+    // Should be able to submit job and run successfully.
+    val result = rdd3.collect().sorted
+    assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30))
+  }
 }
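Of the new cases, "partial partitions" is the one users are most likely to hit indirectly: actions such as first() and take() submit a job on a subset of partitions, which is also why createResultStage passes partitions.toSet.size into the check. A short illustrative snippet of how that surfaces, assuming a local SparkContext named sc as in the suite's beforeEach:

val rdd = sc.parallelize(1 to 10, 4)
  .barrier()
  .mapPartitions((iter, context) => iter)
// first() initially runs a job on a single partition, so the barrier stage would
// launch fewer tasks than the 4 the barrier RDD expects; with this commit the job
// fails fast at submission with
// ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN.
rdd.first()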
