@@ -20,74 +20,134 @@ package org.apache.spark
 import scala.concurrent.duration._
 import scala.language.postfixOps

+import org.scalatest.BeforeAndAfterEach
+
 import org.apache.spark.rdd.{PartitionPruningRDD, RDD}
+import org.apache.spark.scheduler.DAGScheduler
 import org.apache.spark.util.ThreadUtils

 /**
  * This test suite covers all the cases that shall fail fast on a submitted job that contains
  * one or more barrier stages.
  */
-class BarrierStageOnSubmittedSuite extends SparkFunSuite with LocalSparkContext {
+class BarrierStageOnSubmittedSuite extends SparkFunSuite with BeforeAndAfterEach
+    with LocalSparkContext {
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+
+    val conf = new SparkConf()
+      .setMaster("local[4]")
+      .setAppName("test")
+    sc = new SparkContext(conf)
+  }

-  private def testSubmitJob(sc: SparkContext, rdd: RDD[Int], message: String): Unit = {
+  private def testSubmitJob(
+      sc: SparkContext,
+      rdd: RDD[Int],
+      partitions: Option[Seq[Int]] = None,
+      message: String): Unit = {
     val futureAction = sc.submitJob(
       rdd,
       (iter: Iterator[Int]) => iter.toArray,
-      0 until rdd.partitions.length,
+      partitions.getOrElse(0 until rdd.partitions.length),
       { case (_, _) => return }: (Int, Array[Int]) => Unit,
       { return }
     )

     val error = intercept[SparkException] {
-      ThreadUtils.awaitResult(futureAction, 1 seconds)
+      ThreadUtils.awaitResult(futureAction, 5 seconds)
     }.getCause.getMessage
     assert(error.contains(message))
   }

   test("submit a barrier ResultStage that contains PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
       .mapPartitions((iter, context) => iter)
     testSubmitJob(sc, rdd,
-      "Don't support run a barrier stage that contains PartitionPruningRDD")
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }

   test("submit a barrier ShuffleMapStage that contains PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .barrier()
       .mapPartitions((iter, context) => iter)
       .repartition(2)
       .map(x => x + 1)
     testSubmitJob(sc, rdd,
-      "Don't support run a barrier stage that contains PartitionPruningRDD")
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
   }

   test("submit a barrier stage that doesn't contain PartitionPruningRDD") {
-    val conf = new SparkConf()
-      .setMaster("local[4]")
-      .setAppName("test")
-    sc = new SparkContext(conf)
-
     val prunedRdd = new PartitionPruningRDD(sc.parallelize(1 to 10, 4), index => index > 1)
     val rdd = prunedRdd
       .repartition(2)
       .barrier()
       .mapPartitions((iter, context) => iter)
-
     // Should be able to submit job and run successfully.
     val result = rdd.collect().sorted
     assert(result === Seq(6, 7, 8, 9, 10))
   }
+
+  test("submit a barrier stage with partial partitions") {
+    val rdd = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    testSubmitJob(sc, rdd, Some(Seq(1, 3)),
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with union()") {
+    val rdd1 = sc.parallelize(1 to 10, 2)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(1 to 20, 2)
+    val rdd3 = rdd1
+      .union(rdd2)
+      .map(x => x * 2)
+    // Fail the job on submit because the barrier RDD (rdd1) may not be assigned Task 0.
+    testSubmitJob(sc, rdd3,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with coalesce()") {
+    val rdd = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+      .coalesce(1)
+    // Fail the job on submit because the barrier RDD needs to run on 4 tasks, but the stage
+    // only launches 1 task.
+    testSubmitJob(sc, rdd,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage that contains an RDD that depends on multiple barrier RDDs") {
+    val rdd1 = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(11 to 20, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd3 = rdd1
+      .zip(rdd2)
+      .map(x => x._1 + x._2)
+    testSubmitJob(sc, rdd3,
+      message = DAGScheduler.ERROR_MESSAGE_RUN_BARRIER_WITH_UNSUPPORTED_RDD_CHAIN_PATTERN)
+  }
+
+  test("submit a barrier stage with zip()") {
+    val rdd1 = sc.parallelize(1 to 10, 4)
+      .barrier()
+      .mapPartitions((iter, context) => iter)
+    val rdd2 = sc.parallelize(11 to 20, 4)
+    val rdd3 = rdd1
+      .zip(rdd2)
+      .map(x => x._1 + x._2)
+    // Should be able to submit job and run successfully.
+    val result = rdd3.collect().sorted
+    assert(result === Seq(12, 14, 16, 18, 20, 22, 24, 26, 28, 30))
+  }
 }
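
For context, here is a minimal standalone sketch (not part of this diff) of the user-facing pattern these tests guard against: coalescing a barrier stage changes how many tasks the stage launches, so the job should fail fast at submission instead of hanging at runtime. The object name, master setting, and app name are illustrative assumptions; the barrier API calls match the ones used in the tests above.

import org.apache.spark.{SparkConf, SparkContext, SparkException}

object BarrierCoalesceDemo {
  def main(args: Array[String]): Unit = {
    // local[4] is an assumption; any master with at least 4 slots would do.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[4]").setAppName("barrier-demo"))
    try {
      val rdd = sc.parallelize(1 to 10, 4)
        .barrier()                               // mark this stage as a barrier stage
        .mapPartitions((iter, context) => iter)  // barrier tasks must all launch together
        .coalesce(1)                             // 4 barrier partitions, but only 1 task
      // Expected to throw a SparkException at job submission rather than hang.
      rdd.collect()
    } catch {
      case e: SparkException => println(s"Failed fast as expected: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}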