Skip to content

Commit 9beb78d

Browse files
committed
Add tests for forking new session with inherit config enabled. Update overloaded functions for Java bytecode compatibility.
1 parent 18ce1b8 commit 9beb78d

File tree

7 files changed

+125
-22
lines changed

7 files changed

+125
-22
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,10 +1181,10 @@ class SessionCatalog(
11811181
}
11821182

11831183
/**
1184-
* Get an identical copy of the `SessionCatalog`.
1185-
* The temporary tables and function registry are retained.
1186-
* The table relation cache will not be populated.
1187-
*/
1184+
* Get an identical copy of the `SessionCatalog`.
1185+
* The temporary tables and function registry are retained.
1186+
* The table relation cache will not be populated.
1187+
*/
11881188
override def clone: SessionCatalog = {
11891189
val catalog = new SessionCatalog(
11901190
externalCatalog,

sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,12 @@ class ExperimentalMethods private[sql]() {
4949
@volatile var extraOptimizations: Seq[Rule[LogicalPlan]] = Nil
5050

5151
/**
52-
* Get an identical copy of this `ExperimentalMethods` instance.
53-
* @note This is used when forking a `SparkSession`.
54-
* `clone` is provided here instead of implementing equivalent functionality
55-
* in `SparkSession.clone` since it needs to be updated
56-
* as the class `ExperimentalMethods` is extended or modified.
57-
*/
52+
* Get an identical copy of this `ExperimentalMethods` instance.
53+
* @note This is used when forking a `SparkSession`.
54+
* `clone` is provided here instead of implementing equivalent functionality
55+
* in `SparkSession.clone` since it needs to be updated
56+
* as the class `ExperimentalMethods` is extended or modified.
57+
*/
5858
override def clone: ExperimentalMethods = {
5959
def cloneSeq[T](seq: Seq[T]): Seq[T] = {
6060
val newSeq = new ListBuffer[T]

sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.apache.spark.sql.internal.{CatalogImpl, SessionState, SharedState}
4444
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
4545
import org.apache.spark.sql.sources.BaseRelation
4646
import org.apache.spark.sql.streaming._
47-
import org.apache.spark.sql.types.{DataType, LongType, StructType}
47+
import org.apache.spark.sql.types.{DataType, StructType}
4848
import org.apache.spark.sql.util.ExecutionListenerManager
4949
import org.apache.spark.util.Utils
5050

@@ -202,8 +202,6 @@ class SparkSession private(
202202
/**
203203
* Start a new session with isolated SQL configurations, temporary tables, registered
204204
* functions are isolated, but sharing the underlying `SparkContext` and cached data.
205-
* If inherit is enabled, then SQL configurations, temporary tables, registered functions
206-
* are copied over from parent `SparkSession`.
207205
*
208206
* @note Other than the `SparkContext`, all shared state is initialized lazily.
209207
* This method will force the initialization of the shared state to ensure that parent
@@ -212,11 +210,25 @@ class SparkSession private(
212210
*
213211
* @since 2.0.0
214212
*/
215-
def newSession(inheritSessionState: Boolean = false): SparkSession = {
213+
def newSession(): SparkSession = {
214+
new SparkSession(sparkContext, Some(sharedState))
215+
}
216+
217+
/**
218+
* Start a new session, sharing the underlying `SparkContext` and cached data.
219+
* If inheritSessionState is enabled, then SQL configurations, temporary tables,
220+
* registered functions are copied over from parent `SparkSession`.
221+
*
222+
* @note Other than the `SparkContext`, all shared state is initialized lazily.
223+
* This method will force the initialization of the shared state to ensure that parent
224+
* and child sessions are set up with the same shared state. If the underlying catalog
225+
* implementation is Hive, this will initialize the metastore, which may take some time.
226+
*/
227+
def newSession(inheritSessionState: Boolean): SparkSession = {
216228
if (inheritSessionState) {
217229
new SparkSession(sparkContext, Some(sharedState), Some(sessionState.clone))
218230
} else {
219-
new SparkSession(sparkContext, Some(sharedState))
231+
newSession()
220232
}
221233
}
222234

sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ import org.apache.spark.sql.util.ExecutionListenerManager
4242
*/
4343
private[sql] class SessionState(
4444
sparkSession: SparkSession,
45-
existingSessionState: Option[SessionState] = None) {
45+
existingSessionState: Option[SessionState]) {
46+
47+
private[sql] def this(sparkSession: SparkSession) = {
48+
this(sparkSession, None)
49+
}
4650

4751
// Note: These are all lazy vals because they depend on each other (e.g. conf) and we
4852
// want subclasses to override some of the fields. Otherwise, we would get a lot of NPEs.
@@ -198,8 +202,8 @@ private[sql] class SessionState(
198202
}
199203

200204
/**
201-
* Get an identical copy of the `SessionState`.
202-
*/
205+
* Get an identical copy of the `SessionState`.
206+
*/
203207
override def clone: SessionState = {
204208
new SessionState(sparkSession, Some(this))
205209
}

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderSuite.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
package org.apache.spark.sql
1919

2020
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
21+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
22+
import org.apache.spark.sql.catalyst.rules.Rule
2123

2224
/**
2325
* Test cases for the builder pattern of [[SparkSession]].
@@ -123,4 +125,70 @@ class SparkSessionBuilderSuite extends SparkFunSuite {
123125
session.stop()
124126
}
125127
}
128+
129+
test("fork new session and inherit a copy of the session state") {
130+
val activeSession = SparkSession.builder().master("local").getOrCreate()
131+
val forkedSession = activeSession.newSession(inheritSessionState = true)
132+
133+
assert(forkedSession ne activeSession)
134+
assert(forkedSession.sessionState ne activeSession.sessionState)
135+
136+
forkedSession.stop()
137+
activeSession.stop()
138+
}
139+
140+
test("fork new session and inherit sql config options") {
141+
val activeSession = SparkSession
142+
.builder()
143+
.master("local")
144+
.config("spark-configb", "b")
145+
.getOrCreate()
146+
val forkedSession = activeSession.newSession(inheritSessionState = true)
147+
148+
assert(forkedSession ne activeSession)
149+
assert(forkedSession.conf ne activeSession.conf)
150+
assert(forkedSession.conf.get("spark-configb") == "b")
151+
152+
forkedSession.stop()
153+
activeSession.stop()
154+
}
155+
156+
test("fork new session and inherit function registry and udf") {
157+
val activeSession = SparkSession.builder().master("local").getOrCreate()
158+
activeSession.udf.register("strlenScala", (_: String).length + (_: Int))
159+
val forkedSession = activeSession.newSession(inheritSessionState = true)
160+
161+
assert(forkedSession ne activeSession)
162+
assert(forkedSession.sessionState.functionRegistry ne
163+
activeSession.sessionState.functionRegistry)
164+
assert(forkedSession.sessionState.functionRegistry.lookupFunction("strlenScala").nonEmpty)
165+
166+
forkedSession.stop()
167+
activeSession.stop()
168+
}
169+
170+
test("fork new session and inherit experimental methods") {
171+
object DummyRule1 extends Rule[LogicalPlan] {
172+
def apply(p: LogicalPlan): LogicalPlan = p
173+
}
174+
object DummyRule2 extends Rule[LogicalPlan] {
175+
def apply(p: LogicalPlan): LogicalPlan = p
176+
}
177+
val optimizations = List(DummyRule1, DummyRule2)
178+
179+
val activeSession = SparkSession.builder().master("local").getOrCreate()
180+
activeSession.experimental.extraOptimizations = optimizations
181+
182+
val forkedSession = activeSession.newSession(inheritSessionState = true)
183+
184+
assert(forkedSession ne activeSession)
185+
assert(forkedSession.experimental ne activeSession.experimental)
186+
assert(forkedSession.experimental.extraOptimizations ne
187+
activeSession.experimental.extraOptimizations)
188+
assert(forkedSession.experimental.extraOptimizations.toSet ==
189+
activeSession.experimental.extraOptimizations.toSet)
190+
191+
forkedSession.stop()
192+
activeSession.stop()
193+
}
126194
}

sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,21 @@ class CatalogSuite
493493
}
494494
}
495495

496+
test("clone SessionCatalog") {
497+
// need to test tempTables are cloned
498+
assert(spark.catalog.listTables().collect().isEmpty)
499+
500+
createTempTable("my_temp_table")
501+
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
502+
503+
val forkedSession = spark.newSession(inheritSessionState = true)
504+
assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
505+
506+
dropTable("my_temp_table")
507+
assert(spark.catalog.listTables().collect().map(_.name).toSet == Set())
508+
assert(forkedSession.catalog.listTables().collect().map(_.name).toSet == Set("my_temp_table"))
509+
}
510+
496511
// TODO: add tests for the rest of them
497512

498513
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,11 @@ import org.apache.spark.{SparkConf, SparkContext}
3232
import org.apache.spark.internal.Logging
3333
import org.apache.spark.sql.{SparkSession, SQLContext}
3434
import org.apache.spark.sql.catalyst.analysis._
35-
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
36-
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
3735
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3836
import org.apache.spark.sql.execution.QueryExecution
3937
import org.apache.spark.sql.execution.command.CacheTableCommand
4038
import org.apache.spark.sql.hive._
41-
import org.apache.spark.sql.internal.{SharedState, SQLConf}
39+
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
4240
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
4341
import org.apache.spark.util.{ShutdownHookManager, Utils}
4442

@@ -115,16 +113,22 @@ class TestHiveContext(
115113
private[hive] class TestHiveSparkSession(
116114
@transient private val sc: SparkContext,
117115
@transient private val existingSharedState: Option[SharedState],
116+
existingSessionState: Option[SessionState],
118117
private val loadTestTables: Boolean)
119118
extends SparkSession(sc) with Logging { self =>
120119

121120
def this(sc: SparkContext, loadTestTables: Boolean) {
122121
this(
123122
sc,
124123
existingSharedState = None,
124+
existingSessionState = None,
125125
loadTestTables)
126126
}
127127

128+
def this(sc: SparkContext, existingSharedState: Option[SharedState], loadTestTables: Boolean) {
129+
this(sc, existingSharedState, existingSessionState = None, loadTestTables)
130+
}
131+
128132
{ // set the metastore temporary configuration
129133
val metastoreTempConf = HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false) ++ Map(
130134
ConfVars.METASTORE_INTEGER_JDO_PUSHDOWN.varname -> "true",
@@ -151,7 +155,7 @@ private[hive] class TestHiveSparkSession(
151155
new TestHiveSessionState(self)
152156

153157
override def newSession(): TestHiveSparkSession = {
154-
new TestHiveSparkSession(sc, Some(sharedState), loadTestTables)
158+
new TestHiveSparkSession(sc, Some(sharedState), None, loadTestTables)
155159
}
156160

157161
private var cacheTables: Boolean = false

0 commit comments

Comments
 (0)