Skip to content

Commit 0f167db

Browse files
committed
Clean up tests
1 parent 2740c63 commit 0f167db

File tree

2 files changed

+96
-77
lines changed

2 files changed

+96
-77
lines changed

sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala

Lines changed: 89 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -20,92 +20,105 @@ package org.apache.spark.sql
2020
import org.scalatest.BeforeAndAfterAll
2121
import org.scalatest.BeforeAndAfterEach
2222

23-
import org.apache.spark.SparkContext
2423
import org.apache.spark.SparkFunSuite
2524
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2625
import org.apache.spark.sql.catalyst.rules.Rule
2726

2827
class SessionStateSuite extends SparkFunSuite
2928
with BeforeAndAfterEach with BeforeAndAfterAll {
3029

30+
/**
31+
* A shared SparkSession for all tests in this suite. Make sure you reset any changes to this
32+
* session as this is a singleton HiveSparkSession in HiveSessionStateSuite and it's shared
33+
* with all Hive test suites.
34+
*/
3135
protected var activeSession: SparkSession = _
32-
protected var sparkContext: SparkContext = null
3336

3437
override def beforeAll(): Unit = {
35-
sparkContext = SparkSession.builder().master("local").getOrCreate().sparkContext
36-
}
37-
38-
protected def createSession(): Unit = {
39-
activeSession =
40-
SparkSession.builder().master("local").sparkContext(sparkContext).getOrCreate()
41-
}
42-
43-
override def beforeEach(): Unit = {
44-
createSession()
38+
activeSession = SparkSession.builder().master("local").getOrCreate()
4539
}
4640

4741
override def afterAll(): Unit = {
48-
if (sparkContext != null) {
49-
sparkContext.stop()
42+
if (activeSession != null) {
43+
activeSession.stop()
44+
activeSession = null
5045
}
46+
super.afterAll()
5147
}
5248

5349
test("fork new session and inherit RuntimeConfig options") {
5450
val key = "spark-config-clone"
5551
activeSession.conf.set(key, "active")
56-
57-
// inheritance
58-
val forkedSession = activeSession.cloneSession()
59-
assert(forkedSession ne activeSession)
60-
assert(forkedSession.conf ne activeSession.conf)
61-
assert(forkedSession.conf.get(key) == "active")
62-
63-
// independence
64-
forkedSession.conf.set(key, "forked")
65-
assert(activeSession.conf.get(key) == "active")
66-
activeSession.conf.set(key, "dontcopyme")
67-
assert(forkedSession.conf.get(key) == "forked")
52+
try {
53+
// inheritance
54+
val forkedSession = activeSession.cloneSession()
55+
assert(forkedSession ne activeSession)
56+
assert(forkedSession.conf ne activeSession.conf)
57+
assert(forkedSession.conf.get(key) == "active")
58+
59+
// independence
60+
forkedSession.conf.set(key, "forked")
61+
assert(activeSession.conf.get(key) == "active")
62+
activeSession.conf.set(key, "dontcopyme")
63+
assert(forkedSession.conf.get(key) == "forked")
64+
} finally {
65+
activeSession.conf.unset(key)
66+
}
6867
}
6968

7069
test("fork new session and inherit function registry and udf") {
71-
activeSession.udf.register("strlenScala", (_: String).length + (_: Int))
72-
val forkedSession = activeSession.cloneSession()
73-
74-
// inheritance
75-
assert(forkedSession ne activeSession)
76-
assert(forkedSession.sessionState.functionRegistry ne
77-
activeSession.sessionState.functionRegistry)
78-
assert(forkedSession.sessionState.functionRegistry.lookupFunction("strlenScala").nonEmpty)
79-
80-
// independence
81-
forkedSession.sessionState.functionRegistry.dropFunction("strlenScala")
82-
assert(activeSession.sessionState.functionRegistry.lookupFunction("strlenScala").nonEmpty)
83-
activeSession.udf.register("addone", (_: Int) + 1)
84-
assert(forkedSession.sessionState.functionRegistry.lookupFunction("addone").isEmpty)
70+
val testFuncName1 = "strlenScala"
71+
val testFuncName2 = "addone"
72+
try {
73+
activeSession.udf.register(testFuncName1, (_: String).length + (_: Int))
74+
val forkedSession = activeSession.cloneSession()
75+
76+
// inheritance
77+
assert(forkedSession ne activeSession)
78+
assert(forkedSession.sessionState.functionRegistry ne
79+
activeSession.sessionState.functionRegistry)
80+
assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
81+
82+
// independence
83+
forkedSession.sessionState.functionRegistry.dropFunction(testFuncName1)
84+
assert(activeSession.sessionState.functionRegistry.lookupFunction(testFuncName1).nonEmpty)
85+
activeSession.udf.register(testFuncName2, (_: Int) + 1)
86+
assert(forkedSession.sessionState.functionRegistry.lookupFunction(testFuncName2).isEmpty)
87+
} finally {
88+
activeSession.sessionState.functionRegistry.dropFunction(testFuncName1)
89+
activeSession.sessionState.functionRegistry.dropFunction(testFuncName2)
90+
}
8591
}
8692

8793
test("fork new session and inherit experimental methods") {
88-
object DummyRule1 extends Rule[LogicalPlan] {
89-
def apply(p: LogicalPlan): LogicalPlan = p
90-
}
91-
object DummyRule2 extends Rule[LogicalPlan] {
92-
def apply(p: LogicalPlan): LogicalPlan = p
94+
val originalExtraOptimizations = activeSession.experimental.extraOptimizations
95+
val originalExtraStrategies = activeSession.experimental.extraStrategies
96+
try {
97+
object DummyRule1 extends Rule[LogicalPlan] {
98+
def apply(p: LogicalPlan): LogicalPlan = p
99+
}
100+
object DummyRule2 extends Rule[LogicalPlan] {
101+
def apply(p: LogicalPlan): LogicalPlan = p
102+
}
103+
val optimizations = List(DummyRule1, DummyRule2)
104+
activeSession.experimental.extraOptimizations = optimizations
105+
val forkedSession = activeSession.cloneSession()
106+
107+
// inheritance
108+
assert(forkedSession ne activeSession)
109+
assert(forkedSession.experimental ne activeSession.experimental)
110+
assert(forkedSession.experimental.extraOptimizations.toSet ==
111+
activeSession.experimental.extraOptimizations.toSet)
112+
113+
// independence
114+
forkedSession.experimental.extraOptimizations = List(DummyRule2)
115+
assert(activeSession.experimental.extraOptimizations == optimizations)
116+
activeSession.experimental.extraOptimizations = List(DummyRule1)
117+
assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
118+
} finally {
119+
activeSession.experimental.extraOptimizations = originalExtraOptimizations
120+
activeSession.experimental.extraStrategies = originalExtraStrategies
93121
}
94-
val optimizations = List(DummyRule1, DummyRule2)
95-
activeSession.experimental.extraOptimizations = optimizations
96-
val forkedSession = activeSession.cloneSession()
97-
98-
// inheritance
99-
assert(forkedSession ne activeSession)
100-
assert(forkedSession.experimental ne activeSession.experimental)
101-
assert(forkedSession.experimental.extraOptimizations.toSet ==
102-
activeSession.experimental.extraOptimizations.toSet)
103-
104-
// independence
105-
forkedSession.experimental.extraOptimizations = List(DummyRule2)
106-
assert(activeSession.experimental.extraOptimizations == optimizations)
107-
activeSession.experimental.extraOptimizations = List(DummyRule1)
108-
assert(forkedSession.experimental.extraOptimizations == List(DummyRule2))
109122
}
110123

111124
test("fork new sessions and run query on inherited table") {
@@ -119,19 +132,26 @@ class SessionStateSuite extends SparkFunSuite
119132
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
120133
}
121134

122-
implicit val enc = Encoders.tuple(Encoders.scalaInt, Encoders.STRING)
135+
val spark = activeSession
136+
// Cannot use `import activeSession.implicits._` due to the compiler limitation.
137+
import spark.implicits._
138+
123139
activeSession
124140
.createDataset[(Int, String)](Seq(1, 2, 3).map(i => (i, i.toString)))
125141
.toDF("int", "str")
126142
.createOrReplaceTempView("df")
127-
checkTableExists(activeSession)
128-
129-
val forkedSession = activeSession.cloneSession()
130-
assert(forkedSession ne activeSession)
131-
assert(forkedSession.sessionState ne activeSession.sessionState)
132-
checkTableExists(forkedSession)
133-
checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
134-
checkTableExists(forkedSession.cloneSession()) // clone of clone
143+
try {
144+
checkTableExists(activeSession)
145+
146+
val forkedSession = activeSession.cloneSession()
147+
assert(forkedSession ne activeSession)
148+
assert(forkedSession.sessionState ne activeSession.sessionState)
149+
checkTableExists(forkedSession)
150+
checkTableExists(activeSession.cloneSession()) // ability to clone multiple times
151+
checkTableExists(forkedSession.cloneSession()) // clone of clone
152+
} finally {
153+
activeSession.sql("drop table df")
154+
}
135155
}
136156

137157
test("fork new session and inherit reference to SharedState") {

sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSessionStateSuite.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
2828
class HiveSessionStateSuite extends SessionStateSuite
2929
with TestHiveSingleton with BeforeAndAfterEach {
3030

31-
override def afterEach(): Unit = {}
32-
33-
override def beforeAll(): Unit = {}
34-
35-
override def afterAll(): Unit = {
36-
hiveContext.reset()
31+
override def beforeAll(): Unit = {
32+
// Reuse the singleton session
33+
activeSession = spark
3734
}
3835

39-
override def createSession(): Unit = {
40-
activeSession = spark.newSession() // TestHiveSparkSession from TestHiveSingleton
36+
override def afterAll(): Unit = {
37+
// Set activeSession to null to avoid stopping the singleton session
38+
activeSession = null
39+
super.afterAll()
4140
}
4241
}

0 commit comments

Comments
 (0)