7 changes: 5 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -195,6 +195,7 @@ class SparkContext(config: SparkConf) extends Logging {
private var _conf: SparkConf = _
private var _eventLogDir: Option[URI] = None
private var _eventLogCodec: Option[String] = None
private var _listenerBus: LiveListenerBus = _
private var _env: SparkEnv = _
private var _jobProgressListener: JobProgressListener = _
private var _statusTracker: SparkStatusTracker = _
@@ -247,7 +248,7 @@ class SparkContext(config: SparkConf) extends Logging {
def isStopped: Boolean = stopped.get()

// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus(this)
private[spark] def listenerBus: LiveListenerBus = _listenerBus

// This function allows components created by SparkEnv to be mocked in unit tests:
private[spark] def createSparkEnv(
@@ -423,6 +424,8 @@ class SparkContext(config: SparkConf) extends Logging {

if (master == "yarn" && deployMode == "client") System.setProperty("SPARK_YARN_MODE", "true")

_listenerBus = new LiveListenerBus(_conf)

// "_jobProgressListener" should be set up before creating SparkEnv because when creating
// "SparkEnv", some messages will be posted to "listenerBus" and we should not miss them.
_jobProgressListener = new JobProgressListener(_conf)
@@ -2388,7 +2391,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
}

listenerBus.start()
listenerBus.start(this, _env.metricsSystem)
_listenerBusStarted = true
}

@@ -158,6 +158,12 @@ package object config {
.checkValue(_ > 0, "The capacity of listener bus event queue must not be negative")
.createWithDefault(10000)

private[spark] val LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED =
ConfigBuilder("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed")
.internal()
.intConf
.createWithDefault(128)
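
For illustration only, a hypothetical snippet (not part of the patch) showing how the cap could be adjusted; note the key is marked .internal() above, so it is not a documented user-facing setting:

// Hypothetical usage sketch: raise the cap on the number of listener classes
// that get individual timers, from the default of 128.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.scheduler.listenerbus.metrics.maxListenerClassesTimed", "256")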

// This property sets the root namespace for metrics reporting
private[spark] val METRICS_NAMESPACE = ConfigBuilder("spark.metrics.namespace")
.stringConf
101 changes: 94 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -20,10 +20,16 @@ package org.apache.spark.scheduler
import java.util.concurrent._
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

import scala.collection.mutable
import scala.util.DynamicVariable

import org.apache.spark.SparkContext
import com.codahale.metrics.{Counter, Gauge, MetricRegistry, Timer}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.metrics.source.Source
import org.apache.spark.util.Utils

/**
@@ -33,15 +39,20 @@ import org.apache.spark.util.Utils
* has started will events be actually propagated to all attached listeners. This listener bus
* is stopped when `stop()` is called, and it will drop further events after stopping.
*/
private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {

self =>

import LiveListenerBus._

private var sparkContext: SparkContext = _

// Cap the capacity of the event queue so we get an explicit error (rather than
// an OOM exception) if it's perpetually being added to more quickly than it's being drained.
private lazy val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](
sparkContext.conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY))
private val eventQueue =
new LinkedBlockingQueue[SparkListenerEvent](conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY))

private[spark] val metrics = new LiveListenerBusMetrics(conf, eventQueue)

// Indicate if `start()` is called
private val started = new AtomicBoolean(false)
@@ -67,6 +78,7 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
setDaemon(true)
override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
LiveListenerBus.withinListenerThread.withValue(true) {
val timer = metrics.eventProcessingTime
while (true) {
eventLock.acquire()
self.synchronized {
@@ -82,7 +94,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
}
return
}
postToAll(event)
val timerContext = timer.time()
try {
postToAll(event)
} finally {
timerContext.stop()
}
} finally {
self.synchronized {
processingEvent = false
@@ -93,16 +110,23 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
}
}

override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
metrics.getTimerForListenerClass(listener.getClass.asSubclass(classOf[SparkListenerInterface]))
}

/**
* Start sending events to attached listeners.
*
* This first sends out all buffered events posted before this listener bus has started, then
* listens for any additional events asynchronously while the listener bus is still running.
* This should only be called once.
*
* @param sc Used to stop the SparkContext in case the listener thread dies.
*/
def start(): Unit = {
def start(sc: SparkContext, metricsSystem: MetricsSystem): Unit = {
if (started.compareAndSet(false, true)) {
sparkContext = sc
metricsSystem.registerSource(metrics)
listenerThread.start()
} else {
throw new IllegalStateException(s"$name already started!")
@@ -115,12 +139,12 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
logError(s"$name has already stopped! Dropping event $event")
return
}
metrics.numEventsPosted.inc()
val eventAdded = eventQueue.offer(event)
if (eventAdded) {
eventLock.release()
} else {
onDropEvent(event)
droppedEventsCounter.incrementAndGet()
}

val droppedEvents = droppedEventsCounter.get
@@ -200,6 +224,8 @@ private[spark] class LiveListenerBus(val sparkContext: SparkContext) extends SparkListenerBus {
* Note: `onDropEvent` can be called in any thread.
*/
def onDropEvent(event: SparkListenerEvent): Unit = {
metrics.numDroppedEvents.inc()
droppedEventsCounter.incrementAndGet()
if (logDroppedEvent.compareAndSet(false, true)) {
// Only log the following message once to avoid duplicated annoying logs.
logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
@@ -217,3 +243,64 @@ private[spark] object LiveListenerBus {
val name = "SparkListenerBus"
}

private[spark] class LiveListenerBusMetrics(
conf: SparkConf,
queue: LinkedBlockingQueue[_])
extends Source with Logging {

override val sourceName: String = "LiveListenerBus"
override val metricRegistry: MetricRegistry = new MetricRegistry

/**
* The total number of events posted to the LiveListenerBus. This is a count of the total number
* of events which have been produced by the application and sent to the listener bus, NOT a
* count of the number of events which have been processed and delivered to listeners (or dropped
* without being delivered).
*/
val numEventsPosted: Counter = metricRegistry.counter(MetricRegistry.name("numEventsPosted"))

/**
* The total number of events that were dropped without being delivered to listeners.
*/
val numDroppedEvents: Counter = metricRegistry.counter(MetricRegistry.name("numEventsDropped"))

/**
* The amount of time taken to post a single event to all listeners.
*/
val eventProcessingTime: Timer = metricRegistry.timer(MetricRegistry.name("eventProcessingTime"))

/**
* The number of messages waiting in the queue.
*/
val queueSize: Gauge[Int] = {
metricRegistry.register(MetricRegistry.name("queueSize"), new Gauge[Int]{
override def getValue: Int = queue.size()
})
}

// Guarded by synchronization.
private val perListenerClassTimers = mutable.Map[String, Timer]()

/**
* Returns a timer tracking the processing time of events processed by the given listener class.
* This method is thread-safe.
*/
def getTimerForListenerClass(cls: Class[_ <: SparkListenerInterface]): Option[Timer] = {
synchronized {
val className = cls.getName
val maxTimed = conf.get(LISTENER_BUS_METRICS_MAX_LISTENER_CLASSES_TIMED)
perListenerClassTimers.get(className).orElse {
if (perListenerClassTimers.size == maxTimed) {
logError(s"Not measuring processing time for listener class $className because a " +
s"maximum of $maxTimed listener classes are already timed.")
None
} else {
perListenerClassTimers(className) =
metricRegistry.timer(MetricRegistry.name("listenerProcessingTime", className))
perListenerClassTimers.get(className)
}
}
}
}
}
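
For context, a rough sketch (hypothetical, not part of the diff) of how the registered metrics might be read back once start() has registered this source with the MetricsSystem; the setup below is an illustrative assumption:

import com.codahale.metrics.MetricRegistry
import org.apache.spark.SparkConf
import org.apache.spark.scheduler.LiveListenerBus

// Hypothetical illustration; only compiles from inside the org.apache.spark
// package, because both the constructor and `metrics` are private[spark].
val bus = new LiveListenerBus(new SparkConf())
val registry: MetricRegistry = bus.metrics.metricRegistry
// Registered names include numEventsPosted, numEventsDropped, eventProcessingTime,
// queueSize, and listenerProcessingTime.<listener class name> once per-class timers exist.
val postedSoFar: Long = registry.counter(MetricRegistry.name("numEventsPosted")).getCount
val meanPostTimeNanos: Double =
  registry.timer(MetricRegistry.name("eventProcessingTime")).getSnapshot.getMean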

33 changes: 28 additions & 5 deletions core/src/main/scala/org/apache/spark/util/ListenerBus.scala
@@ -23,29 +23,41 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import scala.util.control.NonFatal

import com.codahale.metrics.Timer

import org.apache.spark.internal.Logging

/**
* An event bus which posts events to its listeners.
*/
private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {

private[this] val listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])]

// Marked `private[spark]` for access in tests.
private[spark] val listeners = new CopyOnWriteArrayList[L]
private[spark] def listeners = listenersPlusTimers.asScala.map(_._1).asJava

/**
* Returns a CodaHale metrics Timer for measuring the listener's event processing time.
* This method is intended to be overridden by subclasses.
*/
protected def getTimer(listener: L): Option[Timer] = None

/**
* Add a listener to listen to events. This method is thread-safe and can be called in any thread.
*/
final def addListener(listener: L): Unit = {
listeners.add(listener)
listenersPlusTimers.add((listener, getTimer(listener)))
}

/**
* Remove a listener and it won't receive any events. This method is thread-safe and can be called
* in any thread.
*/
final def removeListener(listener: L): Unit = {
listeners.remove(listener)
listenersPlusTimers.asScala.find(_._1 eq listener).foreach { listenerAndTimer =>
listenersPlusTimers.remove(listenerAndTimer)
Contributor: Since this is a CopyOnWriteArrayList, shall we just do a filter and create a new array?

Contributor (author): I think the only reason CopyOnWriteArrayList was used was thread-safety and fast read performance interleaved with very rare mutations/writes. If we were to replace the array list, we'd need to add a synchronized block to guard the listenersPlusTimers field itself. Given the workload and access patterns here, I'm not sure it's worth trying to optimize this removeListener() method any further.
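
For illustration only, a rough sketch of the filter-and-rebuild alternative discussed above (hypothetical, not what this patch does); it shows the extra locking that reassigning the field would require:

// Hypothetical sketch, assuming the surrounding ListenerBus trait and its existing
// scala.collection.JavaConverters._ import. Rebuilding the list reassigns the field,
// so removeListener (and addListener) would need to synchronize, and readers would
// need a volatile field, which the in-place CopyOnWriteArrayList.remove() above avoids.
@volatile private var listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])]

final def removeListener(listener: L): Unit = synchronized {
  val remaining = listenersPlusTimers.asScala.filterNot(_._1 eq listener)
  listenersPlusTimers = new CopyOnWriteArrayList[(L, Option[Timer])](remaining.asJava)
}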

}
}

/**
@@ -56,14 +68,25 @@ private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging {
// JavaConverters can create a JIterableWrapper if we use asScala.
// However, this method will be called frequently. To avoid the wrapper cost, here we use
// Java Iterator directly.
val iter = listeners.iterator
val iter = listenersPlusTimers.iterator
while (iter.hasNext) {
val listener = iter.next()
val listenerAndMaybeTimer = iter.next()
val listener = listenerAndMaybeTimer._1
val maybeTimer = listenerAndMaybeTimer._2
val maybeTimerContext = if (maybeTimer.isDefined) {
maybeTimer.get.time()
} else {
null
}
try {
doPostEvent(listener, event)
} catch {
case NonFatal(e) =>
logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
} finally {
if (maybeTimerContext != null) {
Reviewer: Same here; simpler with an Option.
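
A rough sketch of the Option-based variant being suggested (hypothetical; the patch itself keeps the null check below):

// Hypothetical rewrite of the timing block above using Option instead of a nullable context:
val maybeTimerContext = maybeTimer.map(_.time())
try {
  doPostEvent(listener, event)
} catch {
  case NonFatal(e) =>
    logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
} finally {
  maybeTimerContext.foreach(_.stop())
}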

maybeTimerContext.stop()
}
}
}
}
@@ -25,12 +25,14 @@ import scala.io.Source

import org.apache.hadoop.fs.Path
import org.json4s.jackson.JsonMethods._
import org.mockito.Mockito
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.io._
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{JsonProtocol, Utils}

/**
@@ -155,17 +157,18 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext with BeforeAndAfter with Logging {
extraConf.foreach { case (k, v) => conf.set(k, v) }
val logName = compressionCodec.map("test-" + _).getOrElse("test")
val eventLogger = new EventLoggingListener(logName, None, testDirPath.toUri(), conf)
val listenerBus = new LiveListenerBus(sc)
val listenerBus = new LiveListenerBus(conf)
val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None,
125L, "Mickey", None)
val applicationEnd = SparkListenerApplicationEnd(1000L)

// A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite
eventLogger.start()
listenerBus.start()
listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem]))
listenerBus.addListener(eventLogger)
listenerBus.postToAll(applicationStart)
listenerBus.postToAll(applicationEnd)
listenerBus.stop()
eventLogger.stop()

// Verify file contains exactly the two events logged