@@ -95,7 +95,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
}

// Sort the output if there is a sort ordering defined.
dep.keyOrdering match {
val resultIter = dep.keyOrdering match {
case Some(keyOrd: Ordering[K]) =>
// Create an ExternalSorter to sort the data.
val sorter =
@@ -104,9 +104,16 @@ private[spark] class BlockStoreShuffleReader[K, C](
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
// Use a completion callback to stop the sorter if the task finished or was cancelled.
context.addTaskCompletionListener(_ => {
sorter.stop()
})
CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
case None =>
aggregatedIter
}
// Use another interruptible iterator here to support task cancellation, as the aggregator
// and/or sorter may have consumed the previous interruptible iterator.
new InterruptibleIterator[Product2[K, C]](context, resultIter)
Contributor:
There is a chance that resultIter is already an InterruptibleIterator, and we should not double-wrap it. Can you send a follow-up PR to fix this? Then we can backport them to 2.3 together.

Contributor Author:
Will do
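
A minimal sketch of the kind of guard the follow-up could add (hypothetical; the actual follow-up PR may shape this differently):

```scala
// Only wrap when the result is not already interruptible, to avoid double wrapping.
resultIter match {
  case _: InterruptibleIterator[_] => resultIter
  case _ => new InterruptibleIterator[Product2[K, C]](context, resultIter)
}
```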

}
}
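
For readers unfamiliar with the pattern, a minimal sketch of what a completion-wrapping iterator does (a simplified stand-in, not Spark's actual CompletionIterator):

```scala
// Simplified stand-in: run a cleanup function once the wrapped iterator has been fully
// consumed. The task-completion listener added in the hunk above covers the other case,
// where the iterator is never exhausted (e.g. the task is cancelled).
class CleanupIterator[A](sub: Iterator[A], cleanup: () => Unit) extends Iterator[A] {
  private var cleaned = false
  override def hasNext: Boolean = {
    val more = sub.hasNext
    if (!more && !cleaned) {
      cleaned = true
      cleanup()
    }
    more
  }
  override def next(): A = sub.next()
}
```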
65 changes: 64 additions & 1 deletion core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.util.ThreadUtils

/**
Expand All @@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
override def afterEach() {
try {
resetSparkContext()
JobCancellationSuite.taskStartedSemaphore.drainPermits()
JobCancellationSuite.taskCancelledSemaphore.drainPermits()
JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits()
JobCancellationSuite.executionOfInterruptibleCounter.set(0)
} finally {
super.afterEach()
}
@@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
f2.get()
}

test("interruptible iterator of shuffle reader") {
// In this test case, we create a Spark job with two stages. The second stage is cancelled
// during execution, and a counter is used to make sure that the corresponding tasks are
// indeed cancelled.
import JobCancellationSuite._
sc = new SparkContext("local[2]", "test interruptible iterator")

val taskCompletedSem = new Semaphore(0)

sc.addSparkListener(new SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
// Release taskCancelledSemaphore when the cancelTasks event has been posted.
if (stageCompleted.stageInfo.stageId == 1) {
taskCancelledSemaphore.release(1000)
}
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
if (taskEnd.stageId == 1) { // make sure tasks are completed
taskCompletedSem.release()
}
}
})

val f = sc.parallelize(1 to 1000).map { i => (i, i) }
.repartitionAndSortWithinPartitions(new HashPartitioner(1))
.mapPartitions { iter =>
taskStartedSemaphore.release()
iter
}.foreachAsync { x =>
if (x._1 >= 10) {
// This block of code is only partially executed: it blocks once x._1 >= 10, and the next
// iteration is cancelled if the source iterator is interruptible. In that case the counter
// is incremented at most 10 times (for x._1 in 1 to 10).
taskCancelledSemaphore.acquire()
}
executionOfInterruptibleCounter.getAndIncrement()
}

taskStartedSemaphore.acquire()
// The job is cancelled when:
// 1. the task in the reduce stage has started, which is guaranteed by the previous line, and
// 2. the task in the reduce stage is blocked after processing at most 10 records, as
// taskCancelledSemaphore is not released until the cancelTasks event is posted.
// After the job is cancelled, the task in the reduce stage is cancelled and no more iterations
// are executed.
f.cancel()

val e = intercept[SparkException](f.get()).getCause
assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))

// Make sure tasks are indeed completed.
taskCompletedSem.acquire()
assert(executionOfInterruptibleCounter.get() <= 10)
}

def testCount() {
// Cancel before launching any tasks
{
@@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft


object JobCancellationSuite {
// To avoid any headaches, reset these global variables in the companion class's afterEach block
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
val twoJobsSharingStageSemaphore = new Semaphore(0)
val executionOfInterruptibleCounter = new AtomicInteger(0)
}
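
As background, a simplified sketch of the interruptible-iterator idea the new test exercises (Spark's real InterruptibleIterator checks the TaskContext's kill flag rather than a caller-supplied one):

```scala
// Simplified sketch: check a cancellation flag before handing out each element, so a task
// that is killed mid-iteration stops promptly instead of draining the whole iterator.
class CancellableIterator[T](delegate: Iterator[T], isCancelled: () => Boolean) extends Iterator[T] {
  override def hasNext: Boolean = {
    if (isCancelled()) {
      // Spark throws a TaskKilledException at this point; a plain exception keeps the sketch simple.
      throw new RuntimeException("task cancelled")
    }
    delegate.hasNext
  }
  override def next(): T = delegate.next()
}
```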