diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index a5e287257731..079af8c05705 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -156,7 +156,7 @@ def getOrCreate(self):
             default.
 
             >>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
-            >>> s1.conf.get("k1") == s1.sparkContext.getConf().get("k1") == "v1"
+            >>> s1.conf.get("k1") == "v1"
             True
 
             In case an existing SparkSession is returned, the config options specified
@@ -179,19 +179,13 @@ def getOrCreate(self):
                     sparkConf = SparkConf()
                     for key, value in self._options.items():
                         sparkConf.set(key, value)
-                    sc = SparkContext.getOrCreate(sparkConf)
                     # This SparkContext may be an existing one.
-                    for key, value in self._options.items():
-                        # we need to propagate the confs
-                        # before we create the SparkSession. Otherwise, confs like
-                        # warehouse path and metastore url will not be set correctly (
-                        # these confs cannot be changed once the SparkSession is created).
-                        sc._conf.set(key, value)
+                    sc = SparkContext.getOrCreate(sparkConf)
+                    # Do not update `SparkConf` for existing `SparkContext`, as it's shared
+                    # by all sessions.
                     session = SparkSession(sc)
                 for key, value in self._options.items():
                     session._jsparkSession.sessionState().conf().setConfString(key, value)
-                for key, value in self._options.items():
-                    session.sparkContext._conf.set(key, value)
                 return session
 
     builder = Builder()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 74642d46d1cd..64a7ceb3fea9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -80,7 +80,7 @@
 _have_pyarrow = _pyarrow_requirement_message is None
 _test_compiled = _test_not_compiled_message is None
 
-from pyspark import SparkContext
+from pyspark import SparkConf, SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -283,6 +283,50 @@ def test_invalid_create_row(self):
         self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
 
 
+class SparkSessionBuilderTests(unittest.TestCase):
+
+    def test_create_spark_context_first_then_spark_session(self):
+        sc = None
+        session = None
+        try:
+            conf = SparkConf().set("key1", "value1")
+            sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
+            session = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+            self.assertEqual(session.conf.get("key1"), "value1")
+            self.assertEqual(session.conf.get("key2"), "value2")
+            self.assertEqual(session.sparkContext, sc)
+
+            self.assertFalse(sc.getConf().contains("key2"))
+            self.assertEqual(sc.getConf().get("key1"), "value1")
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
+    def test_another_spark_session(self):
+        session1 = None
+        session2 = None
+        try:
+            session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
+            session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+            self.assertEqual(session1.conf.get("key1"), "value1")
+            self.assertEqual(session2.conf.get("key1"), "value1")
+            self.assertEqual(session1.conf.get("key2"), "value2")
+            self.assertEqual(session2.conf.get("key2"), "value2")
+            self.assertEqual(session1.sparkContext, session2.sparkContext)
+
+            self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
+            self.assertFalse(session1.sparkContext.getConf().contains("key2"))
+        finally:
+            if session1 is not None:
+                session1.stop()
+            if session2 is not None:
+                session2.stop()
+
+
 class SQLTests(ReusedSQLTestCase):
 
     @classmethod
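
For reference, a minimal sketch (not part of the patch) of the behavior the new doctest and SparkSessionBuilderTests assert: options passed to the builder are applied to the session's SQL conf, while the SparkConf of an already-running SparkContext is left untouched. The app name, master, and keys below are illustrative only.

# Illustrative only -- mirrors the assertions in the tests added above.
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

sc = SparkContext("local[4]", "SessionBuilderDemo", conf=SparkConf().set("key1", "value1"))
spark = SparkSession.builder.config("key2", "value2").getOrCreate()

assert spark.conf.get("key1") == "value1"   # context confs remain visible through the session conf
assert spark.conf.get("key2") == "value2"   # builder option lands in the session conf
assert not sc.getConf().contains("key2")    # shared SparkContext conf is not modified

spark.stop()
sc.stop()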