1818package org .apache .spark
1919
2020import java .util .concurrent .Semaphore
21+ import java .util .concurrent .atomic .AtomicInteger
2122
2223import scala .concurrent .ExecutionContext .Implicits .global
2324import scala .concurrent .Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
2627import org .scalatest .BeforeAndAfter
2728import org .scalatest .Matchers
2829
29- import org .apache .spark .scheduler .{SparkListener , SparkListenerTaskStart }
30+ import org .apache .spark .scheduler .{SparkListener , SparkListenerStageCompleted , SparkListenerTaskEnd , SparkListenerTaskStart }
3031import org .apache .spark .util .ThreadUtils
3132
3233/**
@@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
4041 override def afterEach () {
4142 try {
4243 resetSparkContext()
44+ JobCancellationSuite .taskStartedSemaphore.drainPermits()
45+ JobCancellationSuite .taskCancelledSemaphore.drainPermits()
46+ JobCancellationSuite .twoJobsSharingStageSemaphore.drainPermits()
47+ JobCancellationSuite .executionOfInterruptibleCounter.set(0 )
4348 } finally {
4449 super .afterEach()
4550 }
@@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
320325 f2.get()
321326 }
322327
328+ test(" interruptible iterator of shuffle reader" ) {
329+ // In this test case, we create a Spark job of two stages. The second stage is cancelled during
330+ // execution and a counter is used to make sure that the corresponding tasks are indeed
331+ // cancelled.
332+ import JobCancellationSuite ._
333+ sc = new SparkContext (" local[2]" , " test interruptible iterator" )
334+
335+ val taskCompletedSem = new Semaphore (0 )
336+
337+ sc.addSparkListener(new SparkListener {
338+ override def onStageCompleted (stageCompleted : SparkListenerStageCompleted ): Unit = {
339+ // release taskCancelledSemaphore when cancelTasks event has been posted
340+ if (stageCompleted.stageInfo.stageId == 1 ) {
341+ taskCancelledSemaphore.release(1000 )
342+ }
343+ }
344+
345+ override def onTaskEnd (taskEnd : SparkListenerTaskEnd ): Unit = {
346+ if (taskEnd.stageId == 1 ) { // make sure tasks are completed
347+ taskCompletedSem.release()
348+ }
349+ }
350+ })
351+
352+ val f = sc.parallelize(1 to 1000 ).map { i => (i, i) }
353+ .repartitionAndSortWithinPartitions(new HashPartitioner (1 ))
354+ .mapPartitions { iter =>
355+ taskStartedSemaphore.release()
356+ iter
357+ }.foreachAsync { x =>
358+ if (x._1 >= 10 ) {
359+ // This block of code is partially executed. It will be blocked when x._1 >= 10 and the
360+ // next iteration will be cancelled if the source iterator is interruptible. Then in this
361+ // case, the maximum num of increment would be 10(|1...10|)
362+ taskCancelledSemaphore.acquire()
363+ }
364+ executionOfInterruptibleCounter.getAndIncrement()
365+ }
366+
367+ taskStartedSemaphore.acquire()
368+ // Job is cancelled when:
369+ // 1. task in reduce stage has been started, guaranteed by previous line.
370+ // 2. task in reduce stage is blocked after processing at most 10 records as
371+ // taskCancelledSemaphore is not released until cancelTasks event is posted
372+ // After job being cancelled, task in reduce stage will be cancelled and no more iteration are
373+ // executed.
374+ f.cancel()
375+
376+ val e = intercept[SparkException ](f.get()).getCause
377+ assert(e.getMessage.contains(" cancelled" ) || e.getMessage.contains(" killed" ))
378+
379+ // Make sure tasks are indeed completed.
380+ taskCompletedSem.acquire()
381+ assert(executionOfInterruptibleCounter.get() <= 10 )
382+ }
383+
323384 def testCount () {
324385 // Cancel before launching any tasks
325386 {
@@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
381442
382443
383444object JobCancellationSuite {
445+ // To avoid any headaches, reset these global variables in the companion class's afterEach block
384446 val taskStartedSemaphore = new Semaphore (0 )
385447 val taskCancelledSemaphore = new Semaphore (0 )
386448 val twoJobsSharingStageSemaphore = new Semaphore (0 )
449+ val executionOfInterruptibleCounter = new AtomicInteger (0 )
387450}
0 commit comments