From 9d927262d6bc9a78880703095de521fbcffb9bba Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Sun, 8 Dec 2019 19:49:36 -0800
Subject: [PATCH 1/2] fix method calls of checkAnswer

---
 .../sql/avro/JavaAvroFunctionsSuite.java      |  9 +----
 .../apache/spark/sql/JavaSaveLoadSuite.java   |  5 +--
 .../org/apache/spark/sql/QueryTest.scala      | 34 +++++++++++++------
 .../ReduceNumShufflePartitionsSuite.scala     | 25 +++++---------
 .../binaryfile/BinaryFileFormatSuite.scala    |  6 ++--
 .../spark/sql/streaming/StreamSuite.scala     |  6 ++--
 .../spark/sql/hive/JavaDataFrameSuite.java    |  5 +--
 .../hive/JavaMetastoreDataSourcesSuite.java   |  9 +----
 .../execution/AggregationQuerySuite.scala     |  2 +-
 9 files changed, 44 insertions(+), 57 deletions(-)

diff --git a/external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java b/external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
index a448583dddfb..cf4bba0f7f31 100644
--- a/external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
+++ b/external/avro/src/test/java/org/apache/spark/sql/avro/JavaAvroFunctionsSuite.java
@@ -44,13 +44,6 @@ public void tearDown() {
     spark.stop();
   }
 
-  private static void checkAnswer(Dataset<Row> actual, Dataset<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected.collectAsList());
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
-  }
-
   @Test
   public void testToAvroFromAvro() {
     Dataset<Long> rangeDf = spark.range(10);
@@ -69,6 +62,6 @@ public void testToAvroFromAvro() {
         from_avro(avroDF.col("a"), avroTypeLong),
         from_avro(avroDF.col("b"), avroTypeStr));
 
-    checkAnswer(actual, df);
+    QueryTest$.MODULE$.checkAnswer(actual, df.collectAsList());
   }
 }
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java
index 127d272579a6..875cb913ed7c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaSaveLoadSuite.java
@@ -43,10 +43,7 @@ public class JavaSaveLoadSuite {
   Dataset<Row> df;
 
   private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
+    QueryTest$.MODULE$.checkAnswer(actual, expected);
   }
 
   @Before
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 3039a4ccb677..cd82a9294e2e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 
+import org.junit.Assert
+import org.scalatest.Assertions
+
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.SQLExecution
@@ -150,10 +153,7 @@ abstract class QueryTest extends PlanTest {
 
     assertEmptyMissingInput(analyzedDF)
 
-    QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
+    QueryTest.checkAnswer(analyzedDF, expectedAnswer)
   }
 
   protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
@@ -235,7 +235,21 @@ abstract class QueryTest extends PlanTest {
   }
 }
 
-object QueryTest {
+object QueryTest extends Assertions {
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   *
+   * @param df the [[DataFrame]] to be executed
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
+   */
+  def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = {
+    getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match {
+      case Some(errorMessage) => fail(errorMessage)
+      case None =>
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * If there was exception during the execution or the contents of the DataFrame does not
@@ -246,7 +260,7 @@ object QueryTest {
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
    * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
    */
-  def checkAnswer(
+  def getErrorMessageInCheckAnswer(
       df: DataFrame,
       expectedAnswer: Seq[Row],
       checkToRDD: Boolean = true): Option[String] = {
@@ -408,10 +422,10 @@ object QueryTest {
     }
   }
 
-  def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
-    checkAnswer(df, expectedAnswer.asScala) match {
-      case Some(errorMessage) => errorMessage
-      case None => null
+  def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = {
+    getErrorMessageInCheckAnswer(df, expectedAnswer.asScala) match {
+      case Some(errorMessage) => Assert.fail(errorMessage)
+      case None =>
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 21ec1ac9bda0..fe07b1ff109b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -258,13 +258,6 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
   val numInputPartitions: Int = 10
 
-  def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
-    QueryTest.checkAnswer(actual, expectedAnswer) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
-  }
-
   def withSparkSession(
       f: SparkSession => Unit,
       targetPostShuffleInputSize: Int,
@@ -309,7 +302,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       val agg = df.groupBy("key").count()
 
       // Check the answer first.
-      checkAnswer(
+      QueryTest.checkAnswer(
         agg,
         spark.range(0, 20).selectExpr("id", "50 as cnt").collect())
 
@@ -356,7 +349,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
           .range(0, 1000)
           .selectExpr("id % 500 as key", "id as value")
           .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-        checkAnswer(
+        QueryTest.checkAnswer(
           join,
           expectedAnswer.collect())
 
@@ -408,7 +401,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         spark
           .range(0, 500)
          .selectExpr("id", "2 as cnt")
-        checkAnswer(
+        QueryTest.checkAnswer(
           join,
           expectedAnswer.collect())
 
@@ -460,7 +453,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         spark
          .range(0, 1000)
          .selectExpr("id % 500 as key", "2 as cnt", "id as value")
-        checkAnswer(
+        QueryTest.checkAnswer(
           join,
           expectedAnswer.collect())
 
@@ -504,7 +497,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         // Check the answer first.
        val expectedAnswer = spark.range(0, 500).selectExpr("id % 500", "id as value")
           .union(spark.range(500, 1000).selectExpr("id % 500", "id as value"))
-        checkAnswer(
+        QueryTest.checkAnswer(
           join,
           expectedAnswer.collect())
 
@@ -534,7 +527,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       // ReusedQueryStage 0
       // ReusedQueryStage 0
       val resultDf = df.join(df, "key").join(df, "key")
-      checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
+      QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
       assert(finalPlan.collect { case p: ReusedQueryStageExec => p }.length == 2)
@@ -550,7 +543,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       val grouped = df.groupBy("key").agg(max("value").as("value"))
       val resultDf2 = grouped.groupBy(col("key") + 1).max("value")
         .union(grouped.groupBy(col("key") + 2).max("value"))
-      checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
+      QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
 
       val finalPlan2 = resultDf2.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -580,7 +573,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
      val ds = spark.range(3)
      val resultDf = ds.repartition(2, ds.col("id")).toDF()
 
-      checkAnswer(resultDf,
+      QueryTest.checkAnswer(resultDf,
         Seq(0, 1, 2).map(i => Row(i)))
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -596,7 +589,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
       val resultDf = df1.union(df2)
 
-      checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
+      QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
 
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
index 70ec9bbf4819..2cd142f91307 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala
@@ -352,15 +352,15 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
         .select(CONTENT)
     }
     val expected = Seq(Row(content))
-    QueryTest.checkAnswer(readContent(), expected)
+    checkAnswer(readContent(), expected)
     withSQLConf(SOURCES_BINARY_FILE_MAX_LENGTH.key -> content.length.toString) {
-      QueryTest.checkAnswer(readContent(), expected)
+      checkAnswer(readContent(), expected)
     }
     // Disable read. If the implementation attempts to read, the exception would be different.
    file.setReadable(false)
     val caught = intercept[SparkException] {
       withSQLConf(SOURCES_BINARY_FILE_MAX_LENGTH.key -> (content.length - 1).toString) {
-        QueryTest.checkAnswer(readContent(), expected)
+        checkAnswer(readContent(), expected)
       }
     }
     assert(caught.getMessage.contains("exceeds the max length allowed"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 75ad041ccb80..a637b42c6b03 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -755,9 +755,9 @@ class StreamSuite extends StreamTest {
       inputData.addData(9)
       streamingQuery.processAllAvailable()
 
-      QueryTest.checkAnswer(spark.table("counts").toDF(),
-        Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
-        Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
+      checkAnswer(spark.table("counts").toDF(),
+        Row(1, 1L) :: Row(2, 1L) :: Row(3, 2L) :: Row(4, 2L) ::
+        Row(5, 2L) :: Row(6, 2L) :: Row(7, 1L) :: Row(8, 1L) :: Row(9, 1L) :: Nil)
     } finally {
       if (streamingQuery ne null) {
         streamingQuery.stop()
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index 636ce10da373..ffbccc90b45a 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -39,10 +39,7 @@ public class JavaDataFrameSuite {
   Dataset<Row> df;
 
   private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
+    QueryTest$.MODULE$.checkAnswer(actual, expected);
   }
 
   @Before
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
index 5304052b45a4..ba58b92ddb02 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -51,13 +51,6 @@ public class JavaMetastoreDataSourcesSuite {
   FileSystem fs;
   Dataset<Row> df;
 
-  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
-  }
-
   @Before
   public void setUp() throws IOException {
     sqlContext = TestHive$.MODULE$;
@@ -100,7 +93,7 @@ public void saveTableAndQueryIt() {
       .options(options)
       .saveAsTable("javaSavedTable");
 
-    checkAnswer(
+    QueryTest$.MODULE$.checkAnswer(
       sqlContext.sql("SELECT * FROM javaSavedTable"),
       df.collectAsList());
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 4a3277f5a7e4..f84b854048e8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -1063,7 +1063,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
         // todo: remove it?
        val newActual = Dataset.ofRows(spark, actual.logicalPlan)
 
-        QueryTest.checkAnswer(newActual, expectedAnswer) match {
+        QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
           case Some(errorMessage) =>
             val newErrorMessage =
               s"""

From 7fd4057c5bd2c3c71426319d087b9ea58a8aff0e Mon Sep 17 00:00:00 2001
From: Gengliang Wang
Date: Sun, 8 Dec 2019 20:49:40 -0800
Subject: [PATCH 2/2] remove [[]] in code comment

---
 .../test/scala/org/apache/spark/sql/QueryTest.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index cd82a9294e2e..4a21ae924203 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -239,8 +239,8 @@ object QueryTest extends Assertions {
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    *
-   * @param df the [[DataFrame]] to be executed
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
    * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
    */
   def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = {
@@ -253,11 +253,11 @@ object QueryTest extends Assertions {
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * If there was exception during the execution or the contents of the DataFrame does not
-   * match the expected result, an error message will be returned. Otherwise, a [[None]] will
+   * match the expected result, an error message will be returned. Otherwise, a None will
    * be returned.
   *
-   * @param df the [[DataFrame]] to be executed
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
    * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
    */
   def getErrorMessageInCheckAnswer(
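
After both patches, answer checking has three call paths. Scala suites call QueryTest.checkAnswer, which asserts internally through ScalaTest's fail and returns Unit; Java suites call the java.util.List overload, which asserts through JUnit's Assert.fail; and code that needs the error message itself, as HashAggregationQueryWithControlledFallbackSuite does, calls the renamed getErrorMessageInCheckAnswer. A minimal sketch of the three paths, assuming a suite that mixes in SharedSparkSession for the spark session (the suite and test names below are illustrative, not part of the patches):

  import org.apache.spark.sql.{QueryTest, Row}
  import org.apache.spark.sql.test.SharedSparkSession

  // Hypothetical suite for illustration only; `spark` comes from SharedSparkSession.
  class CheckAnswerUsageSuite extends QueryTest with SharedSparkSession {
    test("checkAnswer call paths after the refactoring") {
      val df = spark.range(0, 3).toDF("id")
      val expected = Seq(Row(0L), Row(1L), Row(2L))

      // Scala path: asserts internally (ScalaTest fail) and returns Unit, so
      // call sites no longer pattern-match on an Option[String].
      QueryTest.checkAnswer(df, expected)

      // Java path: the java.util.List overload now fails through JUnit's
      // Assert.fail instead of returning a nullable error-message String.
      QueryTest.checkAnswer(df, java.util.Arrays.asList(Row(0L), Row(1L), Row(2L)))

      // Callers that still need the message itself (e.g. to wrap it with extra
      // context before failing) use the renamed Option-returning method.
      QueryTest.getErrorMessageInCheckAnswer(df, expected) match {
        case Some(errorMessage) => fail(s"unexpected mismatch:\n$errorMessage")
        case None => // answer matched
      }
    }
  }

Keeping the Option-returning variant alongside the asserting ones is what lets the controlled-fallback suite decorate the message with its own context before failing, while every other call site sheds its match/fail boilerplate.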