python/pyspark/sql/session.py (1 change: 1 addition & 0 deletions)

@@ -219,6 +219,7 @@ def __init__(self, sparkContext, jsparkSession=None):
                jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
            else:
                jsparkSession = self._jvm.SparkSession(self._jsc.sc())

Member: Oh haha, let's get rid of this change

Member Author: I'm addicted to whitespace apparently

        self._jsparkSession = jsparkSession
        self._jwrapped = self._jsparkSession.sqlContext()
        self._wrapped = SQLContext(self._sc, self, self._jwrapped)
python/pyspark/sql/tests.py (42 changes: 42 additions & 0 deletions)

@@ -3563,6 +3563,48 @@ def test_query_execution_listener_on_collect_with_arrow(self):
            "The callback from the query execution listener should be called after 'toPandas'")


class SparkExtensionsTest(unittest.TestCase, SQLTestUtils):
Member: I think SQLTestUtils is not needed.

    # These tests are separate because they use 'spark.sql.extensions', which is
    # static and immutable. It can't be set or unset, for example, via `spark.conf`.

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        SPARK_HOME = _find_spark_home()
        filename_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "SparkSessionExtensionSuite.class")
        if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)):
            raise unittest.SkipTest(
                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
                "available. Will skip the related tests.")

        # Note that 'spark.sql.extensions' is a static immutable configuration.
        cls.spark = SparkSession.builder \
            .master("local[4]") \
            .appName(cls.__name__) \
            .config(
                "spark.sql.extensions",
                "org.apache.spark.sql.MyExtensions") \
            .getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_use_custom_class_for_extensions(self):
        self.assertTrue(
            self.spark._jsparkSession.sessionState().planner().strategies().contains(
                self.spark._jvm.org.apache.spark.sql.MySparkStrategy(self.spark._jsparkSession)),
            "MySparkStrategy not found in active planner strategies")
        self.assertTrue(
            self.spark._jsparkSession.sessionState().analyzer().extendedResolutionRules().contains(
                self.spark._jvm.org.apache.spark.sql.MyRule(self.spark._jsparkSession)),
            "MyRule not found in extended resolution rules")


class SparkSessionTests(PySparkTestCase):

    # This test is separate because it's closely related to the session's start and stop.
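For context, the MyExtensions configurator named in the SparkExtensionsTest above (together with MySparkStrategy and MyRule) lives on the Scala side in SparkSessionExtensionSuite. A minimal sketch of such a configurator, with assumed no-op bodies rather than the suite's actual ones:

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // A planner strategy and an analyzer rule that do nothing, plus a
    // configurator that injects them. Case classes give the structural
    // equality that the contains() assertions in the test rely on.
    case class MySparkStrategy(spark: SparkSession) extends Strategy {
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
    }

    case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(e: SparkSessionExtensions): Unit = {
        e.injectPlannerStrategy(MySparkStrategy)
        e.injectResolutionRule(MyRule)
      }
    }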
sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala (42 changes: 25 additions & 17 deletions)

@@ -86,6 +86,7 @@ class SparkSession private(

  private[sql] def this(sc: SparkContext) {
    this(sc, None, None, new SparkSessionExtensions)
    SparkSession.applyExtensionsFromConf(sc.getConf, this.extensions)
Member: Let's add some comments on why this is only here in this constructor. It might otherwise look odd that this constructor specifically needs to run applyExtensionsFromConf on its own.

  }

  sparkContext.assertNotStopped()
@@ -935,23 +936,7 @@ object SparkSession extends Logging {
          // Do not update `SparkConf` for existing `SparkContext`, as it's shared by all sessions.
        }

-       // Initialize extensions if the user has defined a configurator class.
-       val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
-       if (extensionConfOption.isDefined) {
-         val extensionConfClassName = extensionConfOption.get
-         try {
-           val extensionConfClass = Utils.classForName(extensionConfClassName)
-           val extensionConf = extensionConfClass.newInstance()
-             .asInstanceOf[SparkSessionExtensions => Unit]
-           extensionConf(extensions)
-         } catch {
-           // Ignore the error if we cannot find the class or when the class has the wrong type.
-           case e @ (_: ClassCastException |
-             _: ClassNotFoundException |
-             _: NoClassDefFoundError) =>
-             logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
-         }
-       }
+       applyExtensionsFromConf(sparkContext.conf, extensions)

        session = new SparkSession(sparkContext, None, None, extensions)
        options.foreach { case (k, v) => session.initialSessionOptions.put(k, v) }
@@ -1136,4 +1121,27 @@ object SparkSession extends Logging {
      SparkSession.clearDefaultSession()
    }
  }

  /**
   * Initialize extensions if the user has defined a configurator class in their SparkConf.
   * This class will be applied to the extensions passed into this function.
   */
  private[sql] def applyExtensionsFromConf(conf: SparkConf, extensions: SparkSessionExtensions) {
Member: Let's make it private.

Member: How about returning SparkSessionExtensions from this method, and modifying the secondary constructor of SparkSession as:

    private[sql] def this(sc: SparkContext) {
      this(sc, None, None,
        SparkSession.applyExtensionsFromConf(sc.getConf, new SparkSessionExtensions))
    }

I'm a little worried about whether the order in which we apply the extensions might matter.

Member: On second thought, could we move the method call to the top of the default constructor?

Member Author: The default constructor of SparkSession?

Member Author: I thought about this, and was worried about multiple invocations of the extensions: once every time the SparkSession is cloned.

Member Author: It's difficult here since I'm attempting to cause the least change in behavior for the old code paths :(

Member Author: I am always a little nervous about having functions return objects they take in as parameters and then modify; it gives me the impression that they are stateless. If you think that this is clearer, I can make the change.

Member (@ueshin, Oct 17, 2018): I see, but in that case we need to ensure that no injection of extensions is used in the default constructor, to avoid initializing without the injections from the conf.

Member: Eh .. I think it's okay to have a function that returns the updated extensions.

Member: Actually, either way looks okay.

Member Author: Updated with replacement then :)

    val extensionConfOption = conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS)
Member: I think we can even pass only conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) as its argument instead of the SparkConf, and name it applyExtensions.
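A sketch of that suggested variant (hypothetical: the name applyExtensions and the Option[String] parameter come from the comment above, not from this PR):

    private def applyExtensions(
        extensionConfClassName: Option[String],
        extensions: SparkSessionExtensions): Unit = {
      // Same logic as applyExtensionsFromConf, but the caller extracts the
      // config value, so this method no longer needs the whole SparkConf.
      extensionConfClassName.foreach { className =>
        try {
          val extensionConfClass = Utils.classForName(className)
          val extensionConf = extensionConfClass.newInstance()
            .asInstanceOf[SparkSessionExtensions => Unit]
          extensionConf(extensions)
        } catch {
          // Ignore the error if we cannot find the class or when it has the wrong type.
          case e @ (_: ClassCastException |
              _: ClassNotFoundException |
              _: NoClassDefFoundError) =>
            logWarning(s"Cannot use $className to configure session extensions.", e)
        }
      }
    }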

    if (extensionConfOption.isDefined) {
      val extensionConfClassName = extensionConfOption.get
      try {
        val extensionConfClass = Utils.classForName(extensionConfClassName)
        val extensionConf = extensionConfClass.newInstance()
          .asInstanceOf[SparkSessionExtensions => Unit]
        extensionConf(extensions)
      } catch {
        // Ignore the error if we cannot find the class or when the class has the wrong type.
        case e @ (_: ClassCastException |
            _: ClassNotFoundException |
            _: NoClassDefFoundError) =>
          logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e)
      }
    }
  }
}
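With this change in place, enabling extensions from user code is a matter of setting the static conf before the session is created. A rough sketch (com.example.MyExtensions is a hypothetical class name; it would need to implement SparkSessionExtensions => Unit and be on the driver classpath):

    import org.apache.spark.sql.SparkSession

    // 'spark.sql.extensions' is a static conf: it must be set before the
    // SparkSession (and its underlying SparkContext) is created.
    val spark = SparkSession.builder()
      .master("local[4]")
      .config("spark.sql.extensions", "com.example.MyExtensions")
      .getOrCreate()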