1717
1818package org .apache .spark .sql .execution
1919
20- import scala .language .implicitConversions
21- import scala .reflect .runtime .universe .TypeTag
22- import scala .util .control .NonFatal
23-
2420import org .apache .spark .SparkFunSuite
25-
2621import org .apache .spark .sql .catalyst .analysis .UnresolvedAttribute
2722import org .apache .spark .sql .catalyst .expressions .BoundReference
2823import org .apache .spark .sql .catalyst .util ._
29-
3024import org .apache .spark .sql .test .TestSQLContext
31- import org .apache .spark .sql .{DataFrameHolder , Row , DataFrame }
25+ import org .apache .spark .sql .{DataFrame , DataFrameHolder , Row }
26+
27+ import scala .language .implicitConversions
28+ import scala .reflect .runtime .universe .TypeTag
29+ import scala .util .control .NonFatal
3230
3331/**
3432 * Base class for writing tests for individual physical operators. For an example of how this
@@ -77,13 +75,93 @@ class SparkPlanTest extends SparkFunSuite {
7775 case None =>
7876 }
7977 }
78+
/**
 * Executes the operator under test and a reference operator over the same input and fails the
 * current test when [[SparkPlanTest.checkAnswer]] reports a problem.
 *
 * @param input the input data to be used.
 * @param planFunction builds the physical operator being tested from the input SparkPlan.
 * @param expectedPlanFunction builds a reference implementation of that operator from the input
 *                             SparkPlan; whatever the reference plan produces is treated as the
 *                             source-of-truth for the test.
 */
protected def checkAnswer(
    input: DataFrame,
    planFunction: SparkPlan => SparkPlan,
    expectedPlanFunction: SparkPlan => SparkPlan): Unit = {
  // The companion helper yields Some(message) on any problem and None on success.
  val failure = SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction)
  failure.foreach(message => fail(message))
}
8098}
8199
82100/**
83101 * Helper methods for writing tests of individual physical operators.
84102 */
85103object SparkPlanTest {
86104
/**
 * Executes the operator under test and a reference operator over the same input and compares
 * what they produce.
 *
 * The reference plan is executed first. If it throws a non-fatal exception, the plan under test
 * is not executed and a description of that failure is returned.
 *
 * @param input the input data to be used.
 * @param planFunction builds the physical operator being tested from the input SparkPlan.
 * @param expectedPlanFunction builds a reference implementation of that operator from the input
 *                             SparkPlan; whatever the reference plan produces is treated as the
 *                             source-of-truth for the test.
 * @return None when both plans execute and `compareAnswers` reports no difference; otherwise
 *         Some(message) describing either the exception that was thrown or the mismatch.
 */
def checkAnswer(
    input: DataFrame,
    planFunction: SparkPlan => SparkPlan,
    expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = {

  val outputPlan = planFunction(input.queryExecution.sparkPlan)
  val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)

  // Executes `plan`; a non-fatal failure becomes a Left carrying the formatted report.
  def run(plan: SparkPlan, failureHeader: String): Either[String, Seq[Row]] = {
    try {
      Right(executePlan(input, plan))
    } catch {
      case NonFatal(e) =>
        Left(
          s"""
             | $failureHeader
             | $plan
             | == Exception ==
             | $e
             | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
          """.stripMargin)
    }
  }

  val expectedFailureHeader =
    "Exception thrown while executing Spark plan to calculate expected answer:"
  val actualFailureHeader = "Exception thrown while executing Spark plan:"

  run(expectedOutputPlan, expectedFailureHeader) match {
    case Left(report) => Some(report)
    case Right(expectedAnswer) =>
      // Only reached once the reference plan has produced an answer.
      run(outputPlan, actualFailureHeader) match {
        case Left(report) => Some(report)
        case Right(actualAnswer) =>
          compareAnswers(actualAnswer, expectedAnswer).map { errorMessage =>
            s"""
               | Results do not match.
               | Actual result Spark plan:
               | $outputPlan
               | Expected result Spark plan:
               | $expectedOutputPlan
               | $errorMessage
            """.stripMargin
          }
      }
  }
}
164+
87165 /**
88166 * Runs the plan and makes sure the answer matches the expected result.
89167 * @param input the input data to be used.
@@ -98,22 +176,33 @@ object SparkPlanTest {
98176
99177 val outputPlan = planFunction(input.queryExecution.sparkPlan)
100178
101- // A very simple resolver to make writing tests easier. In contrast to the real resolver
102- // this is always case sensitive and does not try to handle scoping or complex type resolution.
103- val resolvedPlan = outputPlan transform {
104- case plan : SparkPlan =>
105- val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
106- case (a, i) =>
107- (a.name, BoundReference (i, a.dataType, a.nullable))
108- }.toMap
179+ val sparkAnswer : Seq [Row ] = try {
180+ executePlan(input, outputPlan)
181+ } catch {
182+ case NonFatal (e) =>
183+ val errorMessage =
184+ s """
185+ | Exception thrown while executing Spark plan:
186+ | $outputPlan
187+ | == Exception ==
188+ | $e
189+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
190+ """ .stripMargin
191+ return Some (errorMessage)
192+ }
109193
110- plan.transformExpressions {
111- case UnresolvedAttribute (Seq (u)) =>
112- inputMap.getOrElse(u,
113- sys.error(s " Invalid Test: Cannot resolve $u given input $inputMap" ))
114- }
194+ compareAnswers(sparkAnswer, expectedAnswer).map { errorMessage =>
195+ s """
196+ | Results do not match for Spark plan:
197+ | $outputPlan
198+ | $errorMessage
199+ """ .stripMargin
115200 }
201+ }
116202
203+ private def compareAnswers (
204+ sparkAnswer : Seq [Row ],
205+ expectedAnswer : Seq [Row ]): Option [String ] = {
117206 def prepareAnswer (answer : Seq [Row ]): Seq [Row ] = {
118207 // Converts data to types that we can do equality comparison using Scala collections.
119208 // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -130,38 +219,39 @@ object SparkPlanTest {
130219 }
131220 converted.sortBy(_.toString())
132221 }
133-
134- val sparkAnswer : Seq [Row ] = try {
135- resolvedPlan.executeCollect().toSeq
136- } catch {
137- case NonFatal (e) =>
138- val errorMessage =
139- s """
140- | Exception thrown while executing Spark plan:
141- | $outputPlan
142- | == Exception ==
143- | $e
144- | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
145- """ .stripMargin
146- return Some (errorMessage)
147- }
148-
149222 if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
150223 val errorMessage =
151224 s """
152- | Results do not match for Spark plan:
153- | $outputPlan
154225 | == Results ==
155226 | ${sideBySide(
156- s " == Correct Answer - ${expectedAnswer.size} == " +:
227+ s " == Expected Answer - ${expectedAnswer.size} == " +:
157228 prepareAnswer(expectedAnswer).map(_.toString()),
158- s " == Spark Answer - ${sparkAnswer.size} == " +:
229+ s " == Actual Answer - ${sparkAnswer.size} == " +:
159230 prepareAnswer(sparkAnswer).map(_.toString())).mkString(" \n " )}
160231 """ .stripMargin
161- return Some (errorMessage)
232+ Some (errorMessage)
233+ } else {
234+ None
162235 }
236+ }
163237
164- None
/**
 * Resolves the test plan's unresolved attributes by name and collects its output rows.
 *
 * @param input the input data of the test; this method's body does not reference it.
 * @param outputPlan the physical plan to resolve and execute.
 * @return the rows returned by `executeCollect()` on the resolved plan.
 */
private def executePlan(input: DataFrame, outputPlan: SparkPlan): Seq[Row] = {
  // A very simple resolver to make writing tests easier. In contrast to the real resolver
  // this is always case sensitive and does not try to handle scoping or complex type resolution.
  val boundPlan = outputPlan.transform {
    case operator: SparkPlan =>
      // Attribute name -> positional reference into the concatenated output of the children.
      val childOutput = operator.children.flatMap(_.output)
      val bindings = childOutput.zipWithIndex.map { case (attribute, ordinal) =>
        (attribute.name, BoundReference(ordinal, attribute.dataType, attribute.nullable))
      }.toMap

      operator.transformExpressions {
        case UnresolvedAttribute(Seq(name)) =>
          bindings.getOrElse(name,
            sys.error(s"Invalid Test: Cannot resolve $name given input $bindings"))
      }
  }
  boundPlan.executeCollect().toSeq
}
166256}
167257
0 commit comments