python/pyspark/sql/session.py (4 additions, 10 deletions)

@@ -156,7 +156,7 @@ def getOrCreate(self):
             default.
 
             >>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
-            >>> s1.conf.get("k1") == s1.sparkContext.getConf().get("k1") == "v1"
+            >>> s1.conf.get("k1") == "v1"
Member: @ueshin Could we also update the migration guide about this change?

Member: In that case, we might have to add the behaviour changes from #18536 to the migration guide as well.

Member: We can do it together.

Member Author: Submitted a PR to update the migration guide: #22682.

             True
 
             In case an existing SparkSession is returned, the config options specified
@@ -179,19 +179,13 @@ def getOrCreate(self):
                     sparkConf = SparkConf()
                     for key, value in self._options.items():
                         sparkConf.set(key, value)
-                    sc = SparkContext.getOrCreate(sparkConf)
                     # This SparkContext may be an existing one.
Member: tiny nit: can we move this comment above sc = ...?

-                    for key, value in self._options.items():
-                        # we need to propagate the confs
-                        # before we create the SparkSession. Otherwise, confs like
-                        # warehouse path and metastore url will not be set correctly (
-                        # these confs cannot be changed once the SparkSession is created).
-                        sc._conf.set(key, value)
+                    sc = SparkContext.getOrCreate(sparkConf)
+                    # Do not update `SparkConf` for existing `SparkContext`, as it's shared
+                    # by all sessions.
                     session = SparkSession(sc)
                 for key, value in self._options.items():
                     session._jsparkSession.sessionState().conf().setConfString(key, value)
-                for key, value in self._options.items():
-                    session.sparkContext._conf.set(key, value)
                 return session

     builder = Builder()
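A minimal sketch of the behaviour this hunk establishes (not part of the PR; the key names and values are illustrative, mirroring the new tests below): when a SparkContext already exists, options passed to the builder are applied to the session's SQL conf via setConfString, while the context's shared SparkConf is left untouched.

    from pyspark import SparkConf, SparkContext
    from pyspark.sql import SparkSession

    # An existing SparkContext, created before any SparkSession.
    sc = SparkContext(master="local[2]", appName="demo",
                      conf=SparkConf().set("k1", "v1"))

    # getOrCreate() reuses `sc`; "k2" goes into the session's SQL conf only.
    spark = SparkSession.builder.config("k2", "v2").getOrCreate()

    print(spark.conf.get("k1"))         # 'v1' -- visible through the session conf
    print(spark.conf.get("k2"))         # 'v2' -- session-level only
    print(sc.getConf().contains("k2"))  # False -- the shared SparkConf is untouched

    spark.stop()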
python/pyspark/sql/tests.py (45 additions, 1 deletion)

@@ -80,7 +80,7 @@
 _have_pyarrow = _pyarrow_requirement_message is None
 _test_compiled = _test_not_compiled_message is None
 
-from pyspark import SparkContext
+from pyspark import SparkConf, SparkContext
 from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
@@ -283,6 +283,50 @@ def test_invalid_create_row(self):
         self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
 
 
+class SparkSessionBuilderTests(unittest.TestCase):
+
+    def test_create_spark_context_first_then_spark_session(self):
+        sc = None
+        session = None
+        try:
+            conf = SparkConf().set("key1", "value1")
+            sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
+            session = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+            self.assertEqual(session.conf.get("key1"), "value1")
+            self.assertEqual(session.conf.get("key2"), "value2")
+            self.assertEqual(session.sparkContext, sc)
+
+            self.assertFalse(sc.getConf().contains("key2"))
+            self.assertEqual(sc.getConf().get("key1"), "value1")
+        finally:
+            if session is not None:
+                session.stop()
+            if sc is not None:
+                sc.stop()
+
+    def test_another_spark_session(self):
+        session1 = None
+        session2 = None
+        try:
+            session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
+            session2 = SparkSession.builder.config("key2", "value2").getOrCreate()
+
+            self.assertEqual(session1.conf.get("key1"), "value1")
+            self.assertEqual(session2.conf.get("key1"), "value1")
+            self.assertEqual(session1.conf.get("key2"), "value2")
+            self.assertEqual(session2.conf.get("key2"), "value2")
+            self.assertEqual(session1.sparkContext, session2.sparkContext)
+
+            self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
+            self.assertFalse(session1.sparkContext.getConf().contains("key2"))
+        finally:
+            if session1 is not None:
+                session1.stop()
+            if session2 is not None:
+                session2.stop()
+
+
 class SQLTests(ReusedSQLTestCase):
 
     @classmethod
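As an aside grounded in the comment this PR removes (warehouse path and metastore URL cannot be changed once the SparkSession is created): such context-level options still take effect when the builder itself creates the SparkContext. A minimal sketch, assuming no context is running yet; the warehouse path is hypothetical:

    from pyspark.sql import SparkSession

    # No SparkContext exists yet, so the builder creates one and the option is
    # set on the new context's SparkConf before the session is instantiated.
    spark = (SparkSession.builder
             .master("local[2]")
             .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse")
             .getOrCreate())

    print(spark.conf.get("spark.sql.warehouse.dir"))  # '/tmp/spark-warehouse'
    spark.stop()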