Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/pyspark/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,8 @@ def __init__(self, sparkContext, jhiveContext=None):
"SparkSession.builder.enableHiveSupport().getOrCreate() instead.",
DeprecationWarning)
if jhiveContext is None:
sparkSession = SparkSession.builder.enableHiveSupport().getOrCreate()
sparkContext._conf.set("spark.sql.catalogImplementation", "hive")
sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this change?

else:
sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)
Expand Down
30 changes: 16 additions & 14 deletions python/pyspark/sql/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class Builder(object):

_lock = RLock()
_options = {}
_sc = None

@since(2.0)
def config(self, key=None, value=None, conf=None):
Expand Down Expand Up @@ -139,6 +140,10 @@ def enableHiveSupport(self):
"""
return self.config("spark.sql.catalogImplementation", "hive")

def _sparkContext(self, sc):
self._sc = sc
return self

@since(2.0)
def getOrCreate(self):
"""Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
Expand All @@ -150,7 +155,7 @@ def getOrCreate(self):
default.

>>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
>>> s1.conf.get("k1") == s1.sparkContext.getConf().get("k1") == "v1"
>>> s1.conf.get("k1") == "v1"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ueshin Could we also update the migration guide about this change?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, we might have to put the behaviour changes by #18536 together to the migration guide as well.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do it together.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Submitted a pr to update the migration guide #22682.

True

In case an existing SparkSession is returned, the config options specified
Expand All @@ -167,22 +172,19 @@ def getOrCreate(self):
from pyspark.conf import SparkConf
session = SparkSession._instantiatedSession
if session is None or session._sc._jsc is None:
sparkConf = SparkConf()
for key, value in self._options.items():
sparkConf.set(key, value)
sc = SparkContext.getOrCreate(sparkConf)
# This SparkContext may be an existing one.
for key, value in self._options.items():
# we need to propagate the confs
# before we create the SparkSession. Otherwise, confs like
# warehouse path and metastore url will not be set correctly (
# these confs cannot be changed once the SparkSession is created).
sc._conf.set(key, value)
if self._sc is not None:
sc = self._sc
else:
sparkConf = SparkConf()
for key, value in self._options.items():
sparkConf.set(key, value)
sc = SparkContext.getOrCreate(sparkConf)
# This SparkContext may be an existing one.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tiny nit: can we move this comment above sc = ...

# Do not update `SparkConf` for existing `SparkContext`, as it's shared
# by all sessions.
session = SparkSession(sc)
for key, value in self._options.items():
session._jsparkSession.sessionState().conf().setConfString(key, value)
for key, value in self._options.items():
session.sparkContext._conf.set(key, value)
return session

builder = Builder()
Expand Down
46 changes: 45 additions & 1 deletion python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
_have_pyarrow = _pyarrow_requirement_message is None
_test_compiled = _test_not_compiled_message is None

from pyspark import SparkContext
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type, _make_type_verifier
Expand Down Expand Up @@ -283,6 +283,50 @@ def test_invalid_create_row(self):
self.assertRaises(ValueError, lambda: row_class(1, 2, 3))


class SparkSessionBuilderTests(unittest.TestCase):

def test_create_spark_context_first_then_spark_session(self):
sc = None
session = None
try:
conf = SparkConf().set("key1", "value1")
sc = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
session = SparkSession.builder.config("key2", "value2").getOrCreate()

self.assertEqual(session.conf.get("key1"), "value1")
self.assertEqual(session.conf.get("key2"), "value2")
self.assertEqual(session.sparkContext, sc)

self.assertFalse(sc.getConf().contains("key2"))
self.assertEqual(sc.getConf().get("key1"), "value1")
finally:
if session is not None:
session.stop()
if sc is not None:
sc.stop()

def test_another_spark_session(self):
session1 = None
session2 = None
try:
session1 = SparkSession.builder.config("key1", "value1").getOrCreate()
session2 = SparkSession.builder.config("key2", "value2").getOrCreate()

self.assertEqual(session1.conf.get("key1"), "value1")
self.assertEqual(session2.conf.get("key1"), "value1")
self.assertEqual(session1.conf.get("key2"), "value2")
self.assertEqual(session2.conf.get("key2"), "value2")
self.assertEqual(session1.sparkContext, session2.sparkContext)

self.assertEqual(session1.sparkContext.getConf().get("key1"), "value1")
self.assertFalse(session1.sparkContext.getConf().contains("key2"))
finally:
if session1 is not None:
session1.stop()
if session2 is not None:
session2.stop()


class SQLTests(ReusedSQLTestCase):

@classmethod
Expand Down