
Commit f2cab56

advancedxy authored and cloud-fan committed
[SPARK-23040][CORE] Returns interruptible iterator for shuffle reader
## What changes were proposed in this pull request?

Before this commit, a non-interruptible iterator was returned if an aggregator or ordering was specified. This commit returns an interruptible iterator in that case as well, and also ensures that the sorter is closed even when the task is cancelled (killed) in the middle of sorting.

## How was this patch tested?

Added a unit test in JobCancellationSuite.

Author: Xianjin YE <[email protected]>

Closes #20449 from advancedxy/SPARK-23040.
1 parent b0f422c · commit f2cab56
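For context, the fix hinges on how Spark's InterruptibleIterator behaves: it wraps a delegate iterator and re-checks the task's kill flag on every hasNext call, so a running task notices cancellation between records instead of running to completion. Below is a minimal, dependency-free sketch of that idea; InterruptibleLikeIterator, CancelledException, and the AtomicBoolean flag are illustrative stand-ins for Spark's InterruptibleIterator, TaskKilledException, and TaskContext's interrupted flag, not actual Spark APIs.

import java.util.concurrent.atomic.AtomicBoolean

// Thrown when the wrapper observes the cancellation flag; stands in for
// Spark's TaskKilledException.
class CancelledException(msg: String) extends RuntimeException(msg)

// Wraps a delegate iterator and re-checks the flag between elements.
class InterruptibleLikeIterator[T](cancelled: AtomicBoolean, delegate: Iterator[T])
    extends Iterator[T] {

  override def hasNext: Boolean = {
    if (cancelled.get()) throw new CancelledException("task was cancelled")
    delegate.hasNext
  }

  override def next(): T = delegate.next()
}

object InterruptibleLikeIteratorDemo {
  def main(args: Array[String]): Unit = {
    val cancelled = new AtomicBoolean(false)
    val iter = new InterruptibleLikeIterator(cancelled, Iterator.from(1))
    println(iter.next())   // 1: proceeds normally while not cancelled
    cancelled.set(true)    // simulate a kill request between records
    try iter.hasNext
    catch { case e: CancelledException => println(e.getMessage) }
  }
}

The bug addressed by this commit was that this wrapping was lost once an aggregator or key ordering forced the shuffle reader to build a fresh iterator over the aggregated or sorted data.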

2 files changed: +72 −2 lines


core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 8 additions & 1 deletion
@@ -94,7 +94,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     }
 
     // Sort the output if there is a sort ordering defined.
-    dep.keyOrdering match {
+    val resultIter = dep.keyOrdering match {
       case Some(keyOrd: Ordering[K]) =>
         // Create an ExternalSorter to sort the data.
         val sorter =
@@ -103,9 +103,16 @@ private[spark] class BlockStoreShuffleReader[K, C](
         context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
         context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
         context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+        // Use completion callback to stop sorter if task was finished/cancelled.
+        context.addTaskCompletionListener(_ => {
+          sorter.stop()
+        })
         CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
       case None =>
         aggregatedIter
     }
+    // Use another interruptible iterator here to support task cancellation as aggregator or(and)
+    // sorter may have consumed previous interruptible iterator.
+    new InterruptibleIterator[Product2[K, C]](context, resultIter)
   }
 }
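The hunk above makes two distinct fixes: the sorter is now stopped via a task-completion callback, so its resources are released even if the task is killed mid-sort and the iterator is never exhausted, and the final result is re-wrapped in an InterruptibleIterator, because the aggregator or sorter has already drained the earlier interruptible one. A standalone sketch of the same shape follows; ToySorter, ToyTaskContext, and read are illustrative stand-ins for ExternalSorter, TaskContext, and BlockStoreShuffleReader.read, not Spark APIs.

import scala.collection.mutable.ArrayBuffer

object ShuffleReaderSketch {

  // Stand-in for ExternalSorter: buffers records and must be stopped to
  // release its resources (spill files and memory, in real Spark).
  final class ToySorter[T: Ordering] {
    private val buf = ArrayBuffer.empty[T]
    var stopped = false
    def insertAll(it: Iterator[T]): Unit = buf ++= it
    def iterator: Iterator[T] = buf.sorted.iterator
    def stop(): Unit = stopped = true
  }

  // Stand-in for TaskContext: completion listeners plus a kill flag.
  final class ToyTaskContext {
    private val listeners = ArrayBuffer.empty[() => Unit]
    @volatile var interrupted = false
    def addTaskCompletionListener(f: () => Unit): Unit = listeners += f
    def markTaskCompleted(): Unit = listeners.foreach(_.apply())
  }

  def read[T: Ordering](ctx: ToyTaskContext, input: Iterator[T]): Iterator[T] = {
    val sorter = new ToySorter[T]
    sorter.insertAll(input)
    // Fix 1: register cleanup up front, so the sorter is stopped even when
    // the task is killed and the iterator below is abandoned mid-consumption.
    ctx.addTaskCompletionListener(() => sorter.stop())
    val resultIter = sorter.iterator
    // Fix 2: wrap the sorted output so every hasNext re-checks the kill flag.
    new Iterator[T] {
      def hasNext: Boolean = {
        if (ctx.interrupted) throw new RuntimeException("task killed")
        resultIter.hasNext
      }
      def next(): T = resultIter.next()
    }
  }

  def main(args: Array[String]): Unit = {
    val ctx = new ToyTaskContext
    val it = read(ctx, Iterator(3, 1, 2))
    println(it.next())      // 1: sorted output while the task is alive
    ctx.interrupted = true  // simulate a kill between records
    try it.hasNext catch { case e: RuntimeException => println(e.getMessage) }
    ctx.markTaskCompleted() // the completion listener stops the sorter
  }
}

Note that the real code also keeps the CompletionIterator, which stops the sorter promptly when the output is fully consumed; the completion listener covers the cancellation path where consumption stops early.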

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 64 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.util.concurrent.Semaphore
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
@@ -26,7 +27,7 @@ import scala.concurrent.duration._
 import org.scalatest.BeforeAndAfter
 import org.scalatest.Matchers
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted, SparkListenerTaskEnd, SparkListenerTaskStart}
 import org.apache.spark.util.ThreadUtils
 
 /**
@@ -40,6 +41,10 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
   override def afterEach() {
     try {
       resetSparkContext()
+      JobCancellationSuite.taskStartedSemaphore.drainPermits()
+      JobCancellationSuite.taskCancelledSemaphore.drainPermits()
+      JobCancellationSuite.twoJobsSharingStageSemaphore.drainPermits()
+      JobCancellationSuite.executionOfInterruptibleCounter.set(0)
     } finally {
       super.afterEach()
     }
@@ -320,6 +325,62 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
     f2.get()
   }
 
+  test("interruptible iterator of shuffle reader") {
+    // In this test case, we create a Spark job of two stages. The second stage is cancelled during
+    // execution and a counter is used to make sure that the corresponding tasks are indeed
+    // cancelled.
+    import JobCancellationSuite._
+    sc = new SparkContext("local[2]", "test interruptible iterator")
+
+    val taskCompletedSem = new Semaphore(0)
+
+    sc.addSparkListener(new SparkListener {
+      override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+        // release taskCancelledSemaphore when cancelTasks event has been posted
+        if (stageCompleted.stageInfo.stageId == 1) {
+          taskCancelledSemaphore.release(1000)
+        }
+      }
+
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+        if (taskEnd.stageId == 1) { // make sure tasks are completed
+          taskCompletedSem.release()
+        }
+      }
+    })
+
+    val f = sc.parallelize(1 to 1000).map { i => (i, i) }
+      .repartitionAndSortWithinPartitions(new HashPartitioner(1))
+      .mapPartitions { iter =>
+        taskStartedSemaphore.release()
+        iter
+      }.foreachAsync { x =>
+        if (x._1 >= 10) {
+          // This block of code is partially executed. It will be blocked when x._1 >= 10 and the
+          // next iteration will be cancelled if the source iterator is interruptible. Then in this
+          // case, the maximum num of increment would be 10(|1...10|)
+          taskCancelledSemaphore.acquire()
+        }
+        executionOfInterruptibleCounter.getAndIncrement()
+      }
+
+    taskStartedSemaphore.acquire()
+    // Job is cancelled when:
+    // 1. task in reduce stage has been started, guaranteed by previous line.
+    // 2. task in reduce stage is blocked after processing at most 10 records as
+    //    taskCancelledSemaphore is not released until cancelTasks event is posted
+    // After job being cancelled, task in reduce stage will be cancelled and no more iteration are
+    // executed.
+    f.cancel()
+
+    val e = intercept[SparkException](f.get()).getCause
+    assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed"))
+
+    // Make sure tasks are indeed completed.
+    taskCompletedSem.acquire()
+    assert(executionOfInterruptibleCounter.get() <= 10)
+  }
+
   def testCount() {
     // Cancel before launching any tasks
     {
@@ -381,7 +442,9 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft
 
 
 object JobCancellationSuite {
+  // To avoid any headaches, reset these global variables in the companion class's afterEach block
   val taskStartedSemaphore = new Semaphore(0)
   val taskCancelledSemaphore = new Semaphore(0)
   val twoJobsSharingStageSemaphore = new Semaphore(0)
+  val executionOfInterruptibleCounter = new AtomicInteger(0)
 }
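The test's semaphore choreography is easier to see outside the diff: records 1 through 9 pass straight through, record 10 blocks on taskCancelledSemaphore, and permits are only released after the stage is cancelled, so the counter can never exceed 10 if the iterator is truly interruptible. The following is a plain-threads sketch of that gating pattern; SemaphoreGateDemo and all names in it are illustrative, not part of the test or of Spark.

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}

object SemaphoreGateDemo {
  def main(args: Array[String]): Unit = {
    val started = new Semaphore(0)            // worker signals it has begun
    val gate = new Semaphore(0)               // held closed until "cancellation"
    val processed = new AtomicInteger(0)
    val cancelled = new AtomicBoolean(false)  // stands in for the task kill flag

    val worker = new Thread(() => {
      started.release()
      var i = 1
      // The !cancelled check stands in for the interruptible iterator's
      // hasNext re-check between records.
      while (i <= 1000 && !cancelled.get()) {
        if (i >= 10) gate.acquire()           // block until the driver reacts
        processed.incrementAndGet()
        i += 1
      }
    })
    worker.start()

    started.acquire()    // 1. worker has started (mirrors taskStartedSemaphore)
    cancelled.set(true)  // 2. "cancel the job"
    gate.release(1000)   // 3. unblock the worker so it observes cancellation
    worker.join()
    assert(processed.get() <= 10)
    println(s"processed ${processed.get()} records before cancellation took effect")
  }
}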
