diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 2c10779f2b893..dbb8dd198b6a3 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -29,7 +29,7 @@ import java.util.UUID.randomUUID
 import scala.collection.JavaConverters._
 import scala.collection.{Map, Set}
 import scala.collection.generic.Growable
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{ListBuffer, HashMap}
 import scala.reflect.{ClassTag, classTag}
 import scala.util.control.NonFatal
 
@@ -241,6 +241,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   private var _jars: Seq[String] = _
   private var _files: Seq[String] = _
   private var _shutdownHookRef: AnyRef = _
+  private val _stopHooks = new ListBuffer[() => Unit]()
 
   /* ------------------------------------------------------------------------------------- *
    | Accessors and public fields. These provide access to the internal state of the        |
@@ -1618,6 +1619,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     listenerBus.post(SparkListenerUnpersistRDD(rddId))
   }
 
+  /**
+   * Adds a stop hook which can be used to clean up additional resources. This is called when
+   * the SparkContext is being stopped.
+   */
+  private[spark] def addStopHook(hook: () => Unit): Unit = {
+    _stopHooks += hook
+  }
+
   /**
    * Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
    * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
@@ -1764,10 +1773,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     // Unset YARN mode system env variable, to allow switching between cluster types.
     System.clearProperty("SPARK_YARN_MODE")
     SparkContext.clearActiveContext()
+    _stopHooks.foreach(hook => Utils.tryLogNonFatalError {
+      hook()
+    })
     logInfo("Successfully stopped SparkContext")
   }
 
-
   /**
    * Get Spark's home location from either a value set through the constructor,
    * or the spark.home Java property, or the SPARK_HOME environment variable
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1c2ac5f6f11bf..b34b01610d5cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -78,6 +78,12 @@
   }
 
   def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
+  sparkContext.addStopHook(() => {
+    SQLContext.clearInstantiatedContext()
+    SQLContext.clearActive()
+    SQLContext.clearSqlListener()
+  })
+
   // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
   // wants to create a new root SQLContext (a SLQContext that is not created by newSession).
   private val allowMultipleContexts =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index f93d081d0c30e..ff585db5d3586 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -343,6 +343,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite {
       .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly
     val sc = new SparkContext(conf)
     try {
+      // Clear the SQL listener created by a previous test suite.
+      SQLContext.clearSqlListener()
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._
      // Run 100 successful executions and 100 failed executions.
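
Note: as a standalone illustration (not part of the patch), the sketch below mirrors the stop-hook pattern introduced above: hooks are collected in a `ListBuffer` while the context is alive, and each one is run inside a non-fatal-error guard at shutdown so that one failing hook cannot block the remaining cleanup. The `StopHookSketch` object and its `println`-based logging are hypothetical stand-ins for `SparkContext._stopHooks` and `Utils.tryLogNonFatalError`.

```scala
import scala.collection.mutable.ListBuffer
import scala.util.control.NonFatal

object StopHookSketch {
  private val stopHooks = new ListBuffer[() => Unit]()

  // Register a cleanup action to run when stop() is called.
  def addStopHook(hook: () => Unit): Unit = stopHooks += hook

  // Run every registered hook; a non-fatal failure in one hook is logged
  // and does not prevent the remaining hooks from running.
  def stop(): Unit = stopHooks.foreach { hook =>
    try hook() catch {
      case NonFatal(e) => println(s"Ignoring non-fatal error in stop hook: $e")
    }
  }

  def main(args: Array[String]): Unit = {
    addStopHook(() => println("clearing cached SQL state"))
    addStopHook(() => throw new RuntimeException("boom")) // caught and logged
    addStopHook(() => println("releasing other resources"))
    stop()
  }
}
```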