diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index abece1ec0955..5326ef7098a8 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -22,7 +22,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable -import com.codahale.metrics.{Metric, MetricRegistry} +import com.codahale.metrics._ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SparkConf @@ -68,13 +68,15 @@ import org.apache.spark.util.Utils * [options] represent the specific property of this source or sink. */ private[spark] class MetricsSystem private ( - val instance: String, conf: SparkConf) extends Logging { + val instance: String, + conf: SparkConf, + registry: MetricRegistry) + extends Logging { private[this] val metricsConfig = new MetricsConfig(conf) private val sinks = new mutable.ArrayBuffer[Sink] - private val sources = new mutable.ArrayBuffer[Source] - private val registry = new MetricRegistry() + private val sourcesWithListeners = new mutable.HashMap[Source, MetricRegistryListener] private var running: Boolean = false @@ -108,6 +110,9 @@ private[spark] class MetricsSystem private ( if (running) { sinks.foreach(_.stop) registry.removeMatching((_: String, _: Metric) => true) + sourcesWithListeners.synchronized { + sourcesWithListeners.keySet.foreach(removeSource) + } } else { logWarning("Stopping a MetricsSystem that is not running") } @@ -152,28 +157,24 @@ private[spark] class MetricsSystem private ( } else { defaultName } } - def getSourcesByName(sourceName: String): Seq[Source] = sources.synchronized { - sources.filter(_.sourceName == sourceName).toSeq + def getSourcesByName(sourceName: String): Seq[Source] = sourcesWithListeners.synchronized { + sourcesWithListeners.keySet.filter(_.sourceName == sourceName).toSeq } def registerSource(source: Source): Unit = { - sources.synchronized { - sources += source - } - try { - val regName = buildRegistryName(source) - registry.register(regName, source.metricRegistry) - } catch { - case e: IllegalArgumentException => logInfo("Metrics already registered", e) + val listener = new MetricsSystemListener(buildRegistryName(source)) + source.metricRegistry.addListener(listener) + sourcesWithListeners.synchronized { + sourcesWithListeners += source -> listener } } def removeSource(source: Source): Unit = { - sources.synchronized { - sources -= source - } val regName = buildRegistryName(source) registry.removeMatching((name: String, _: Metric) => name.startsWith(regName)) + sourcesWithListeners.synchronized { + sourcesWithListeners.remove(source).foreach(source.metricRegistry.removeListener) + } } private def registerSources(): Unit = { @@ -224,6 +225,49 @@ private[spark] class MetricsSystem private ( } } } + + private[spark] class MetricsSystemListener(prefix: String) + extends MetricRegistryListener { + def metricName(name: String): String = MetricRegistry.name(prefix, name) + + def registerMetric[T <: Metric](name: String, metric: T): Unit = { + try { + registry.register(metricName(name), metric) + } catch { + case e: IllegalArgumentException => logInfo("Metrics already registered", e) + } + } + + override def onHistogramAdded(name: String, histogram: Histogram): Unit = + registerMetric(name, histogram) + + override def onCounterAdded(name: String, counter: Counter): Unit = + registerMetric(name, counter) + + override def onMeterAdded(name: String, meter: Meter): Unit = + registerMetric(name, meter) + + override def onGaugeAdded(name: String, gauge: Gauge[_]): Unit = + registerMetric(name, gauge) + + override def onTimerAdded(name: String, timer: Timer): Unit = + registerMetric(name, timer) + + override def onHistogramRemoved(name: String): Unit = + registry.remove(metricName(name)) + + override def onGaugeRemoved(name: String): Unit = + registry.remove(metricName(name)) + + override def onMeterRemoved(name: String): Unit = + registry.remove(metricName(name)) + + override def onCounterRemoved(name: String): Unit = + registry.remove(metricName(name)) + + override def onTimerRemoved(name: String): Unit = + registry.remove(metricName(name)) + } } private[spark] object MetricsSystem { @@ -241,8 +285,11 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = { - new MetricsSystem(instance, conf) + def createMetricsSystem( + instance: String, + conf: SparkConf, + registry: MetricRegistry = new MetricRegistry): MetricsSystem = { + new MetricsSystem(instance, conf, registry) } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 31d8492510f0..7fef50806dd2 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.metrics +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.{Gauge, MetricRegistry, MetricRegistryListener} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.internal.config._ import org.apache.spark.metrics.sink.Sink @@ -31,21 +32,20 @@ import org.apache.spark.metrics.source.{Source, StaticSources} class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null - var securityMgr: SecurityManager = null before { filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile conf = new SparkConf(false).set(METRICS_CONF, filePath) - securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) metricsSystem.start() - val sources = PrivateMethod[ArrayBuffer[Source]](Symbol("sources")) + val sources = PrivateMethod[mutable.HashMap[Source, MetricRegistryListener]]( + Symbol("sourcesWithListeners")) val sinks = PrivateMethod[ArrayBuffer[Sink]](Symbol("sinks")) - assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) + assert(metricsSystem.invokePrivate(sources()).size === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 0) assert(metricsSystem.getServletHandlers.nonEmpty) } @@ -53,16 +53,17 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM test("MetricsSystem with sources add") { val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) metricsSystem.start() - val sources = PrivateMethod[ArrayBuffer[Source]](Symbol("sources")) + val sources = PrivateMethod[mutable.HashMap[Source, MetricRegistryListener]]( + Symbol("sourcesWithListeners")) val sinks = PrivateMethod[ArrayBuffer[Sink]](Symbol("sinks")) - assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length) + assert(metricsSystem.invokePrivate(sources()).size === StaticSources.allSources.length) assert(metricsSystem.invokePrivate(sinks()).length === 1) assert(metricsSystem.getServletHandlers.nonEmpty) val source = new MasterSource(null) metricsSystem.registerSource(source) - assert(metricsSystem.invokePrivate(sources()).length === StaticSources.allSources.length + 1) + assert(metricsSystem.invokePrivate(sources()).size === StaticSources.allSources.length + 1) } test("MetricsSystem with Driver instance") { @@ -269,4 +270,24 @@ class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateM assert(metricName === source.sourceName) } + test("MetricsSystem registers dynamically added metrics") { + val registry = new MetricRegistry() + val source = new Source { + override val sourceName = "dummySource" + override val metricRegistry = new MetricRegistry() + } + + val instanceName = "testInstance" + val metricsSystem = MetricsSystem.createMetricsSystem( + instanceName, conf, registry) + metricsSystem.registerSource(source) + assert(!registry.getNames.contains("dummySource.newMetric"), "Metric shouldn't be registered") + + source.metricRegistry.register("newMetric", new Gauge[Integer] { + override def getValue: Integer = 1 + }) + assert(registry.getNames.contains("dummySource.newMetric"), + "Metric should have been registered") + } + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 29eb1db63862..0e4d8d7a0ad7 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -23,9 +23,10 @@ import java.util.Locale import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable import scala.collection.mutable.Queue +import com.codahale.metrics.MetricRegistryListener import org.apache.commons.io.FileUtils import org.scalatest.{Assertions, PrivateMethodTester} import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} @@ -1027,8 +1028,10 @@ object testPackage extends Assertions { * This includes methods to access private methods and fields in StreamingContext and MetricsSystem */ private object StreamingContextSuite extends PrivateMethodTester { - private val _sources = PrivateMethod[ArrayBuffer[Source]](Symbol("sources")) - private def getSources(metricsSystem: MetricsSystem): ArrayBuffer[Source] = { + private val _sources = + PrivateMethod[mutable.HashMap[Source, MetricRegistryListener]](Symbol("sourcesWithListeners")) + private def getSources(metricsSystem: MetricsSystem): + mutable.HashMap[Source, MetricRegistryListener] = { metricsSystem.invokePrivate(_sources()) } private val _streamingSource = PrivateMethod[StreamingSource](Symbol("streamingSource"))