Skip to content

Commit 1c7bad8

Browse files
committed
Make sorting of answers explicit in SparkPlanTest.checkAnswer().
1 parent b81a920 commit 1c7bad8

File tree

4 files changed

+61
-26
lines changed

4 files changed

+61
-26
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ class SortSuite extends SparkPlanTest {
3333

3434
checkAnswer(
3535
input.toDF("a", "b", "c"),
36-
ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
37-
input.sorted)
36+
ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
37+
input.sortBy(t => (t._1, t._2)),
38+
sortAnswers = false)
3839

3940
checkAnswer(
4041
input.toDF("a", "b", "c"),
41-
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
42-
input.sortBy(t => (t._2, t._1)))
42+
ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
43+
input.sortBy(t => (t._2, t._1)),
44+
sortAnswers = false)
4345
}
4446
}

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,19 @@ class SparkPlanTest extends SparkFunSuite {
4646
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
4747
* the physical operator that's being tested.
4848
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
49+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
50+
* to being compared.
4951
*/
5052
protected def checkAnswer(
5153
input: DataFrame,
5254
planFunction: SparkPlan => SparkPlan,
53-
expectedAnswer: Seq[Row]): Unit = {
54-
checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
55+
expectedAnswer: Seq[Row],
56+
sortAnswers: Boolean): Unit = {
57+
checkAnswer(
58+
input :: Nil,
59+
(plans: Seq[SparkPlan]) => planFunction(plans.head),
60+
expectedAnswer,
61+
sortAnswers)
5562
}
5663

5764
/**
@@ -61,14 +68,20 @@ class SparkPlanTest extends SparkFunSuite {
6168
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
6269
* the physical operator that's being tested.
6370
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
71+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
72+
* to being compared.
6473
*/
6574
protected def checkAnswer(
6675
left: DataFrame,
6776
right: DataFrame,
6877
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
69-
expectedAnswer: Seq[Row]): Unit = {
70-
checkAnswer(left :: right :: Nil,
71-
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
78+
expectedAnswer: Seq[Row],
79+
sortAnswers: Boolean): Unit = {
80+
checkAnswer(
81+
left :: right :: Nil,
82+
(plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)),
83+
expectedAnswer,
84+
sortAnswers)
7285
}
7386

7487
/**
@@ -77,12 +90,15 @@ class SparkPlanTest extends SparkFunSuite {
7790
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
7891
* the physical operator that's being tested.
7992
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
93+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
94+
* to being compared.
8095
*/
8196
protected def checkAnswer(
8297
input: Seq[DataFrame],
8398
planFunction: Seq[SparkPlan] => SparkPlan,
84-
expectedAnswer: Seq[Row]): Unit = {
85-
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
99+
expectedAnswer: Seq[Row],
100+
sortAnswers: Boolean): Unit = {
101+
SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers) match {
86102
case Some(errorMessage) => fail(errorMessage)
87103
case None =>
88104
}
@@ -94,13 +110,16 @@ class SparkPlanTest extends SparkFunSuite {
94110
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
95111
* the physical operator that's being tested.
96112
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
113+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
114+
* to being compared.
97115
*/
98116
protected def checkAnswer[A <: Product : TypeTag](
99117
input: DataFrame,
100118
planFunction: SparkPlan => SparkPlan,
101-
expectedAnswer: Seq[A]): Unit = {
119+
expectedAnswer: Seq[A],
120+
sortAnswers: Boolean): Unit = {
102121
val expectedRows = expectedAnswer.map(Row.fromTuple)
103-
checkAnswer(input, planFunction, expectedRows)
122+
checkAnswer(input, planFunction, expectedRows, sortAnswers)
104123
}
105124

106125
/**
@@ -110,14 +129,17 @@ class SparkPlanTest extends SparkFunSuite {
110129
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
111130
* the physical operator that's being tested.
112131
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
132+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
133+
* to being compared.
113134
*/
114135
protected def checkAnswer[A <: Product : TypeTag](
115136
left: DataFrame,
116137
right: DataFrame,
117138
planFunction: (SparkPlan, SparkPlan) => SparkPlan,
118-
expectedAnswer: Seq[A]): Unit = {
139+
expectedAnswer: Seq[A],
140+
sortAnswers: Boolean): Unit = {
119141
val expectedRows = expectedAnswer.map(Row.fromTuple)
120-
checkAnswer(left, right, planFunction, expectedRows)
142+
checkAnswer(left, right, planFunction, expectedRows, sortAnswers)
121143
}
122144

123145
/**
@@ -126,13 +148,16 @@ class SparkPlanTest extends SparkFunSuite {
126148
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
127149
* the physical operator that's being tested.
128150
* @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
151+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
152+
* to being compared.
129153
*/
130154
protected def checkAnswer[A <: Product : TypeTag](
131155
input: Seq[DataFrame],
132156
planFunction: Seq[SparkPlan] => SparkPlan,
133-
expectedAnswer: Seq[A]): Unit = {
157+
expectedAnswer: Seq[A],
158+
sortAnswers: Boolean): Unit = {
134159
val expectedRows = expectedAnswer.map(Row.fromTuple)
135-
checkAnswer(input, planFunction, expectedRows)
160+
checkAnswer(input, planFunction, expectedRows, sortAnswers)
136161
}
137162

138163
/**
@@ -231,11 +256,14 @@ object SparkPlanTest {
231256
* @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
232257
* the physical operator that's being tested.
233258
* @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
259+
* @param sortAnswers if true, the answers will be sorted by their toString representations prior
260+
* to being compared.
234261
*/
235262
def checkAnswer(
236263
input: Seq[DataFrame],
237264
planFunction: Seq[SparkPlan] => SparkPlan,
238-
expectedAnswer: Seq[Row]): Option[String] = {
265+
expectedAnswer: Seq[Row],
266+
sortAnswers: Boolean): Option[String] = {
239267

240268
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
241269

@@ -254,7 +282,7 @@ object SparkPlanTest {
254282
return Some(errorMessage)
255283
}
256284

257-
compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage =>
285+
compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
258286
s"""
259287
| Results do not match for Spark plan:
260288
| $outputPlan
@@ -266,7 +294,7 @@ object SparkPlanTest {
266294
private def compareAnswers(
267295
sparkAnswer: Seq[Row],
268296
expectedAnswer: Seq[Row],
269-
sort: Boolean = true): Option[String] = {
297+
sort: Boolean): Option[String] = {
270298
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
271299
// Converts data to types that we can do equality comparison using Scala collections.
272300
// For BigDecimal type, the Scala type has a better definition of equality test (similar to

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
4040
// TODO: randomized spilling to ensure that merging is tested at least once for every data type.
4141
for (
4242
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
43-
nullable <- Seq(false);
43+
nullable <- Seq(true, false);
4444
sortOrder <- Seq('a.asc :: Nil);
4545
randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
4646
) {

sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,17 @@ class OuterJoinSuite extends SparkPlanTest {
4747
(1, 2.0, null, null),
4848
(2, 1.0, 2, 3.0),
4949
(3, 3.0, null, null)
50-
))
50+
),
51+
sortAnswers = true)
5152

5253
checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
5354
ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
5455
Seq(
5556
(2, 1.0, 2, 3.0),
5657
(null, null, 3, 2.0),
5758
(null, null, 4, 1.0)
58-
))
59+
),
60+
sortAnswers = true)
5961

6062
checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
6163
ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
@@ -65,7 +67,8 @@ class OuterJoinSuite extends SparkPlanTest {
6567
(3, 3.0, null, null),
6668
(null, null, 3, 2.0),
6769
(null, null, 4, 1.0)
68-
))
70+
),
71+
sortAnswers = true)
6972
}
7073

7174
test("broadcast hash outer join") {
@@ -75,14 +78,16 @@ class OuterJoinSuite extends SparkPlanTest {
7578
(1, 2.0, null, null),
7679
(2, 1.0, 2, 3.0),
7780
(3, 3.0, null, null)
78-
))
81+
),
82+
sortAnswers = true)
7983

8084
checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
8185
BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
8286
Seq(
8387
(2, 1.0, 2, 3.0),
8488
(null, null, 3, 2.0),
8589
(null, null, 4, 1.0)
86-
))
90+
),
91+
sortAnswers = true)
8792
}
8893
}

0 commit comments

Comments
 (0)