42 changes: 42 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3563,6 +3563,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
"The callback from the query execution listener should be called after 'toPandas'")


class SparkExtensionsTest(unittest.TestCase):
# These tests are separate because they use 'spark.sql.extensions', which is
# static and immutable. It cannot be set or unset, for example, via `spark.conf`.

@classmethod
def setUpClass(cls):
import glob
from pyspark.find_spark_home import _find_spark_home

SPARK_HOME = _find_spark_home()
filename_pattern = (
"sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
"SparkSessionExtensionSuite.class")
if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
raise unittest.SkipTest(
"'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
"available. Will skip the related tests.")

# Note that 'spark.sql.extensions' is a static immutable configuration.
cls.spark = SparkSession.builder \
.master("local[4]") \
.appName(cls.__name__) \
.config(
"spark.sql.extensions",
"org.apache.spark.sql.MyExtensions") \
.getOrCreate()

@classmethod
def tearDownClass(cls):
cls.spark.stop()

def test_use_custom_class_for_extensions(self):
self.assertTrue(
self.spark._jsparkSession.sessionState().planner().strategies().contains(
self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
"MySparkStrategy not found in active planner strategies")
self.assertTrue(
self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
"MyRule not found in extended resolution rules")


class SparkSessionTests(PySparkTestCase):

# This test is separate because it's closely related with session's start and stop.
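For reference on the Scala side: the test loads `org.apache.spark.sql.MyExtensions`, a configurator defined in `SparkSessionExtensionSuite`, together with the `MyRule` and `MySparkStrategy` it injects. A minimal sketch of what such classes look like (the no-op bodies are illustrative, not the suite's exact code):

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// No-op analyzer rule: leaves every plan unchanged.
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

// No-op planner strategy: matches nothing, deferring to the built-in strategies.
case class MySparkStrategy(spark: SparkSession) extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
}

// The configurator named by 'spark.sql.extensions'. It must be a
// SparkSessionExtensions => Unit with a zero-argument constructor, because
// it is instantiated reflectively (see applyExtensions below).
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(e: SparkSessionExtensions): Unit = {
    e.injectResolutionRule(MyRule)
    e.injectPlannerStrategy(MySparkStrategy)
  }
}
```

The case classes matter for the assertions above: `contains(...)` relies on case-class structural equality to find the injected rule and strategy in the session state.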
56 changes: 38 additions & 18 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -84,8 +84,17 @@ class SparkSession private(
// The call site where this SparkSession was constructed.
private val creationSite: CallSite = Utils.getCallSite()

/**
 * Constructor used in PySpark. It applies the Spark session extensions explicitly, which
 * otherwise happens only during getOrCreate. We cannot move this into the default
 * constructor, since every new session would then re-apply the configured extensions to an
 * already-configured SparkSessionExtensions instance.
 */
private[sql] def this(sc: SparkContext) {
this(sc, None, None, new SparkSessionExtensions)
this(sc, None, None,
SparkSession.applyExtensions(
sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
new SparkSessionExtensions))
}

sparkContext.assertNotStopped()
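In other words, sessions constructed directly from a `SparkContext` (the Py4J path used by PySpark) now honor the configured extensions. A sketch of what the new constructor is equivalent to; `applyExtensions` is private to the `SparkSession` companion object, so this is illustrative rather than callable user code:

```scala
// Sketch: equivalent expansion of the new auxiliary constructor. Previously the
// last argument was always a fresh `new SparkSessionExtensions`, so a configurator
// named by 'spark.sql.extensions' never ran for sessions created through Py4J.
val extensions = SparkSession.applyExtensions(
  sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS), // Option[String]
  new SparkSessionExtensions)                             // mutated in place and returned
val session = new SparkSession(sc, None, None, extensions)
```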
@@ -935,23 +944,9 @@ object SparkSession extends Logging {
// Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
}

// Initialize extensions if the user has defined a configurator class.
val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
if (extensionConfOption.isDefined) {
val extensionConfClassName = extensionConfOption.get
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.newInstance()
.asInstanceOf[SparkSessionExtensions => Unit]
extensionConf(extensions)
} catch {
// Ignore the error if we cannot find the class or when the class has the wrong type.
case e @ (_: ClassCastException |
_: ClassNotFoundException |
_: NoClassDefFoundError) =>
logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
}
}
applyExtensions(
sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
extensions)

session = new SparkSession(sparkContext, None, None, extensions)
options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1136,4 +1131,29 @@
SparkSession.clearDefaultSession()
}
}

/**
 * Initialize extensions for the given extension class name, if defined. The named class is
 * instantiated and applied to the extensions passed into this function.
 */
private def applyExtensions(
extensionOption: Option[String],
extensions: SparkSessionExtensions): SparkSessionExtensions = {
if (extensionOption.isDefined) {
val extensionConfClassName = extensionOption.get
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.newInstance()
.asInstanceOf[SparkSessionExtensions => Unit]
extensionConf(extensions)
} catch {
// Ignore the error if we cannot find the class or when the class has the wrong type.
case e @ (_: ClassCastException |
_: ClassNotFoundException |
_: NoClassDefFoundError) =>
logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
}
}
extensions
}
}
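For completeness, a sketch of the two alternative ways extensions can reach a session on the Scala side; the class names are the hypothetical ones from the test suite, and only the first path goes through `applyExtensions`:

```scala
import org.apache.spark.sql.SparkSession

// 1. Statically, via configuration: the class name is resolved reflectively by
//    applyExtensions when the session is created.
val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.extensions", "org.apache.spark.sql.MyExtensions")
  .getOrCreate()

// 2. Programmatically, via Builder.withExtensions, with no reflection involved:
val spark2 = SparkSession.builder()
  .master("local[2]")
  .withExtensions(e => e.injectResolutionRule(MyRule))
  .getOrCreate()
```

Note that the catch block in `applyExtensions` only logs a warning when the configured class is missing or has the wrong type, so a typo in `spark.sql.extensions` silently degrades to a session without extensions.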