diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
index 11e2c475d9b4..c1e534a4faa9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
@@ -20,6 +20,8 @@ package org.apache.spark.scheduler
 import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
 
+import scala.collection.JavaConverters._
+
 import com.codahale.metrics.{Gauge, Timer}
 
 import org.apache.spark.{SparkConf, SparkContext}
@@ -95,18 +97,26 @@
   }
 
   private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
-    var next: SparkListenerEvent = eventQueue.take()
-    while (next != POISON_PILL) {
-      val ctx = processingTime.time()
-      try {
-        super.postToAll(next)
-      } finally {
-        ctx.stop()
+    try {
+      var next: SparkListenerEvent = eventQueue.take()
+      while (next != POISON_PILL) {
+        val ctx = processingTime.time()
+        try {
+          super.postToAll(next)
+        } finally {
+          ctx.stop()
+        }
+        eventCount.decrementAndGet()
+        next = eventQueue.take()
       }
       eventCount.decrementAndGet()
-      next = eventQueue.take()
+    } catch {
+      case ie: InterruptedException =>
+        logInfo(s"Interrupted while dispatching events in queue $name. " +
+          s"Removing all its listeners: " +
+          s"${listeners.asScala.map(Utils.getFormattedClassName(_)).mkString(",")}.", ie)
+        listeners.asScala.foreach(removeListenerOnError)
     }
-    eventCount.decrementAndGet()
   }
 
   override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
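For context on the change above: the whole take()/postToAll loop now sits inside a single try, so an interrupt delivered to the dispatch thread, whether it arrives inside eventQueue.take() or inside a listener, lands in one catch block that deregisters every listener of the queue instead of letting the thread die and leave a full, permanently stuck queue behind. The following standalone sketch shows the same shape outside of Spark; it is illustrative only, and all names in it (EventLoop, Listener, PoisonPill) are hypothetical, not Spark's API and not part of this patch.

import java.util.concurrent.{CopyOnWriteArrayList, LinkedBlockingQueue}
import scala.collection.JavaConverters._

sealed trait Event
case object PoisonPill extends Event
final case class Message(body: String) extends Event

trait Listener { def onEvent(e: Event): Unit }

final class EventLoop(name: String, capacity: Int) {
  private val queue = new LinkedBlockingQueue[Event](capacity)
  private val listeners = new CopyOnWriteArrayList[Listener]()

  def addListener(l: Listener): Unit = { listeners.add(l) }
  def post(e: Event): Boolean = queue.offer(e) // drops when full, like a bounded bus

  private val thread = new Thread(new Runnable {
    override def run(): Unit = dispatch()
  }, s"dispatch-$name")

  def start(): Unit = thread.start()

  private def dispatch(): Unit = {
    try {
      // Block on take() until the poison pill arrives.
      var next = queue.take()
      while (next != PoisonPill) {
        listeners.asScala.foreach(_.onEvent(next))
        next = queue.take()
      }
    } catch {
      case _: InterruptedException =>
        // Mirrors the patch: rather than dying silently, drop all listeners of
        // this queue so it can be retired instead of going dead while full.
        println(s"Interrupted while dispatching events in queue $name; removing all listeners")
        listeners.clear()
    }
  }

  def stop(): Unit = {
    queue.offer(PoisonPill) // best effort; the loop may already be gone
    thread.join()
  }
}

In the real code, removing the listeners (via removeListenerOnError) also lets LiveListenerBus retire the interrupted queue, which is exactly what the new test below asserts.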
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 7221623f89e1..7b016580c34c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -529,6 +529,52 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     }
   }
 
+  test("interrupt within dispatch is handled correctly") {
+    val conf = new SparkConf(false)
+      .set(LISTENER_BUS_EVENT_QUEUE_CAPACITY, 5)
+    val bus = new LiveListenerBus(conf)
+    val counter1 = new BasicJobCounter()
+    val counter2 = new BasicJobCounter()
+    val interruptingListener1 = new AsyncInterruptingListener()
+    val interruptingListener2 = new AsyncInterruptingListener()
+    bus.addToSharedQueue(counter1)
+    bus.addToSharedQueue(interruptingListener1)
+    bus.addToStatusQueue(counter2)
+    bus.addToEventLogQueue(interruptingListener2)
+    assert(bus.activeQueues() === Set(SHARED_QUEUE, APP_STATUS_QUEUE, EVENT_LOG_QUEUE))
+    assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
+    assert(bus.findListenersByClass[AsyncInterruptingListener]().size === 2)
+
+    sc = new SparkContext("local", "SparkListenerSuite", conf)
+    bus.start(sc, mockMetricsSystem)
+
+    // after we post one event, both interrupting listeners interrupt their dispatch
+    // threads, so the shared and event log queues should lose all listeners and be removed
+    bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded))
+    bus.waitUntilEmpty()
+    interruptingListener1.sleep = false
+    interruptingListener2.sleep = false
+    // wait long enough for the bus to remove the interrupted queues
+    Thread.sleep(1000)
+    // the SparkContext should not be stopped by the interrupts
+    assert(sc.isStopped === false)
+    assert(bus.activeQueues() === Set(APP_STATUS_QUEUE))
+    assert(bus.findListenersByClass[BasicJobCounter]().size === 1)
+    assert(bus.findListenersByClass[AsyncInterruptingListener]().size === 0)
+    assert(counter1.count === 1)
+    assert(counter2.count === 1)
+
+    // posting more events should be fine; they'll just get processed from the OK queue.
+    (0 until 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) }
+    bus.waitUntilEmpty()
+    assert(counter1.count === 1)
+    assert(counter2.count === 6)
+
+    // Make sure stopping works -- this requires putting a poison pill in all active queues,
+    // which would fail if the interrupted queues were still active, as they would be full.
+    bus.stop()
+  }
+
   test("event queue size can be configued through spark conf") {
     // configure the shared queue size to be 1, event log queue size to be 2,
     // and listner bus event queue size to be 5
@@ -627,6 +673,25 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
       }
     }
   }
+
+  /**
+   * A simple listener that asynchronously interrupts the listener thread on job end.
+   */
+  private class AsyncInterruptingListener extends SparkListener {
+    @volatile var sleep = true
+
+    override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+      val listenerThread = Thread.currentThread()
+      new Thread(new Runnable {
+        override def run(): Unit = {
+          while (sleep) {
+            Thread.sleep(10)
+          }
+          listenerThread.interrupt()
+        }
+      }).start()
+    }
+  }
 }
 
 // These classes can't be declared inside of the SparkListenerSuite class because we don't want
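A note on the test helper: unlike the suite's existing synchronous interrupting listener, which throws from inside onJobEnd, AsyncInterruptingListener interrupts the dispatch thread from a separately spawned thread, so the InterruptedException surfaces out of eventQueue.take() rather than out of postToAll. The @volatile sleep flag holds the interrupt back until the test has seen the event fully dispatched. Below is a minimal standalone sketch of that timing, illustrative only, with hypothetical names that are not part of the patch.

import java.util.concurrent.LinkedBlockingQueue

object AsyncInterruptDemo {
  def main(args: Array[String]): Unit = {
    val queue = new LinkedBlockingQueue[String]()
    val consumer = new Thread(new Runnable {
      override def run(): Unit = {
        try {
          while (true) {
            // blocks until an element arrives; an interrupt delivered here
            // throws InterruptedException instead of returning a value
            println(s"processed ${queue.take()}")
          }
        } catch {
          case _: InterruptedException =>
            println("interrupted while blocked in take(); cleaning up")
        }
      }
    }, "consumer")
    consumer.start()

    queue.put("event-1")  // handled normally by the loop
    Thread.sleep(100)     // crude stand-in for the test's sleep flag: let the
                          // consumer finish and block in take() again
    consumer.interrupt()  // now the interrupt lands inside take()
    consumer.join()
  }
}

This is why the test flips sleep to false only after bus.waitUntilEmpty(): the interrupt must land while the dispatch thread is parked in take(), which is the exact branch the patched dispatch() now catches.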