diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 565654e7f03b..3016ffbed63f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3563,6 +3563,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
                 "The callback from the query execution listener should be called after 'toPandas'")
 
 
+class SparkExtensionsTest(unittest.TestCase):
+    # These tests are separate because they use 'spark.sql.extensions', which is
+    # static and immutable. It can't be set or unset, for example, via `spark.conf`.
+
+    @classmethod
+    def setUpClass(cls):
+        import glob
+        from pyspark.find_spark_home import _find_spark_home
+
+        SPARK_HOME = _find_spark_home()
+        filename_pattern = (
+            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
+            "SparkSessionExtensionSuite.class")
+        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
+            raise unittest.SkipTest(
+                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
+                "available. Will skip the related tests.")
+
+        # Note that 'spark.sql.extensions' is a static immutable configuration.
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config(
+                "spark.sql.extensions",
+                "org.apache.spark.sql.MyExtensions") \
+            .getOrCreate()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.spark.stop()
+
+    def test_use_custom_class_for_extensions(self):
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().planner().strategies().contains(
+                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
+            "MySparkStrategy not found in active planner strategies")
+        self.assertTrue(
+            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
+                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
+            "MyRule not found in extended resolution rules")
+
+
 class SparkSessionTests(PySparkTestCase):
 
     # This test is separate because it's closely related with session's start and stop.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 565042fcf762..1154f6c06717 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -84,8 +84,17 @@ class SparkSession private(
   // The call site where this SparkSession was constructed.
   private val creationSite: CallSite = Utils.getCallSite()
 
+  /**
+   * Constructor used in PySpark. Applies the Spark session extensions explicitly, which
+   * otherwise happens only during getOrCreate. We cannot add this to the default constructor,
+   * since that would cause every new session to re-apply the configured extensions on top of
+   * the ones already running.
+   */
   private[sql] def this(sc: SparkContext) {
-    this(sc, None, None, new SparkSessionExtensions)
+    this(sc, None, None,
+      SparkSession.applyExtensions(
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        new SparkSessionExtensions))
   }
 
   sparkContext.assertNotStopped()
@@ -935,23 +944,9 @@ object SparkSession extends Logging {
          // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
        }
 
-        // Initialize extensions if the user has defined a configurator class.
-        val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-        if (extensionConfOption.isDefined) {
-          val extensionConfClassName = extensionConfOption.get
-          try {
-            val extensionConfClass = Utils.classForName(extensionConfClassName)
-            val extensionConf = extensionConfClass.newInstance()
-              .asInstanceOf[SparkSessionExtensions => Unit]
-            extensionConf(extensions)
-          } catch {
-            // Ignore the error if we cannot find the class or when the class has the wrong type.
-            case e @ (_: ClassCastException |
-              _: ClassNotFoundException |
-              _: NoClassDefFoundError) =>
-              logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-          }
-        }
+        applyExtensions(
+          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+          extensions)
 
         session = new SparkSession(sparkContext, None, None, extensions)
         options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1136,4 +1131,29 @@ object SparkSession extends Logging {
       SparkSession.clearDefaultSession()
     }
   }
+
+  /**
+   * Initialize extensions for the given extension class name. The class will be instantiated
+   * and applied to the extensions passed into this function.
+   */
+  private def applyExtensions(
+      extensionOption: Option[String],
+      extensions: SparkSessionExtensions): SparkSessionExtensions = {
+    if (extensionOption.isDefined) {
+      val extensionConfClassName = extensionOption.get
+      try {
+        val extensionConfClass = Utils.classForName(extensionConfClassName)
+        val extensionConf = extensionConfClass.newInstance()
+          .asInstanceOf[SparkSessionExtensions => Unit]
+        extensionConf(extensions)
+      } catch {
+        // Ignore the error if we cannot find the class or when the class has the wrong type.
+        case e @ (_: ClassCastException |
+            _: ClassNotFoundException |
+            _: NoClassDefFoundError) =>
+          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
+      }
+    }
+    extensions
+  }
 }
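For context on what the new `applyExtensions` helper expects to load: the value of `spark.sql.extensions` must name a class that is a `SparkSessionExtensions => Unit`, which the helper instantiates reflectively and invokes against the session's extensions, logging a warning instead of failing if the class is missing or has the wrong type. Below is a minimal sketch (not part of this patch) of such a configurator class; the names `com.example`, `ExampleExtensions`, and `NoOpRule` are hypothetical, while `SparkSessionExtensions.injectResolutionRule` is Spark's existing injection hook.

```scala
package com.example  // hypothetical package, not part of the patch

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// A do-nothing analyzer rule, used only to demonstrate the injection hook.
case class NoOpRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// The configured class must be a SparkSessionExtensions => Unit; this is exactly the
// type that applyExtensions casts to after instantiating it reflectively.
class ExampleExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectResolutionRule(NoOpRule)
  }
}
```

With such a class on the driver classpath, a PySpark user would opt in the same way the new test does, e.g. `SparkSession.builder.config("spark.sql.extensions", "com.example.ExampleExtensions").getOrCreate()` (again, the class name here is hypothetical). Since the option is a static configuration, it has to be supplied before the session is created and cannot be toggled later through `spark.conf`, which is why the test uses a dedicated session in `setUpClass`.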