@@ -44,13 +44,6 @@ public void tearDown() {
     spark.stop();
   }
 
-  private static void checkAnswer(Dataset<Row> actual, Dataset<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected.collectAsList());
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
-  }
-
   @Test
   public void testToAvroFromAvro() {
     Dataset<Long> rangeDf = spark.range(10);
@@ -69,6 +62,6 @@ public void testToAvroFromAvro() {
         from_avro(avroDF.col("a"), avroTypeLong),
         from_avro(avroDF.col("b"), avroTypeStr));
 
-    checkAnswer(actual, df);
+    QueryTest$.MODULE$.checkAnswer(actual, df.collectAsList());
   }
 }
@@ -43,10 +43,7 @@ public class JavaSaveLoadSuite {
   Dataset<Row> df;
 
   private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
+    QueryTest$.MODULE$.checkAnswer(actual, expected);
   }
 
   @Before
40 changes: 27 additions & 13 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,6 +21,9 @@ import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 
+import org.junit.Assert
+import org.scalatest.Assertions
+
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.SQLExecution
@@ -150,10 +153,7 @@ abstract class QueryTest extends PlanTest {
 
     assertEmptyMissingInput(analyzedDF)
 
-    QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
+    QueryTest.checkAnswer(analyzedDF, expectedAnswer)
   }
 
   protected def checkAnswer(df: => DataFrame, expectedAnswer: Row): Unit = {
@@ -235,18 +235,32 @@
   }
 }
 
-object QueryTest {
+object QueryTest extends Assertions {
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   *
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
+   * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
+   */
+  def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = {
Review comment (Member):
We can make this private if it isn't supposed to be called in the tests directly; the companion class can still access it.

Reply (Member Author):
Some suites need this one: ReduceNumShufflePartitionsSuite, SessionStateSuite, and SQLQuerySuite (SQLQuerySuite needs to set the parameter checkToRDD to false). And I prefer not to have duplicated code in these three suites.
+    getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match {
+      case Some(errorMessage) => fail(errorMessage)
+      case None =>
+    }
+  }
+
   /**
    * Runs the plan and makes sure the answer matches the expected result.
    * If there was exception during the execution or the contents of the DataFrame does not
-   * match the expected result, an error message will be returned. Otherwise, a [[None]] will
+   * match the expected result, an error message will be returned. Otherwise, a None will
    * be returned.
    *
-   * @param df the [[DataFrame]] to be executed
-   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   * @param df the DataFrame to be executed
+   * @param expectedAnswer the expected result in a Seq of Rows.
+   * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice.
    */
-  def checkAnswer(
+  def getErrorMessageInCheckAnswer(
       df: DataFrame,
       expectedAnswer: Seq[Row],
       checkToRDD: Boolean = true): Option[String] = {
@@ -408,10 +422,10 @@ object QueryTest {
     }
   }
 
-  def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
-    checkAnswer(df, expectedAnswer.asScala) match {
-      case Some(errorMessage) => errorMessage
-      case None => null
+  def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = {
+    getErrorMessageInCheckAnswer(df, expectedAnswer.asScala) match {
+      case Some(errorMessage) => Assert.fail(errorMessage)
+      case None =>
     }
   }
 }
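To summarize the API after this file's changes: both checkAnswer overloads (the Scala Seq[Row] one and the java.util.List[Row] one used from Java) now fail the test themselves, while getErrorMessageInCheckAnswer preserves the old Option[String]-returning behavior for callers that want to post-process the message. A hedged sketch of the two call patterns, where df and expected are placeholders rather than real test data:

import org.apache.spark.sql.{DataFrame, Row}

def exerciseBothStyles(df: DataFrame, expected: Seq[Row]): Unit = {
  // Common case: asserts internally and throws on mismatch.
  QueryTest.checkAnswer(df, expected)

  // For suites that decorate the failure message before failing,
  // as HashAggregationQueryWithControlledFallbackSuite does below:
  QueryTest.getErrorMessageInCheckAnswer(df, expected).foreach { msg =>
    throw new AssertionError(s"checkAnswer failed:\n$msg")
  }
}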
@@ -258,13 +258,6 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
   val numInputPartitions: Int = 10
 
-  def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
-    QueryTest.checkAnswer(actual, expectedAnswer) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
-  }
-
   def withSparkSession(
       f: SparkSession => Unit,
       targetPostShuffleInputSize: Int,
@@ -309,7 +302,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       val agg = df.groupBy("key").count()
 
       // Check the answer first.
-      checkAnswer(
+      QueryTest.checkAnswer(
         agg,
         spark.range(0, 20).selectExpr("id", "50 as cnt").collect())
 
@@ -356,7 +349,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         .range(0, 1000)
         .selectExpr("id % 500 as key", "id as value")
         .union(spark.range(0, 1000).selectExpr("id % 500 as key", "id as value"))
-      checkAnswer(
+      QueryTest.checkAnswer(
         join,
         expectedAnswer.collect())
 
@@ -408,7 +401,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         spark
           .range(0, 500)
           .selectExpr("id", "2 as cnt")
-      checkAnswer(
+      QueryTest.checkAnswer(
         join,
         expectedAnswer.collect())
 
@@ -460,7 +453,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
         spark
           .range(0, 1000)
           .selectExpr("id % 500 as key", "2 as cnt", "id as value")
-      checkAnswer(
+      QueryTest.checkAnswer(
         join,
         expectedAnswer.collect())
 
@@ -504,7 +497,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       // Check the answer first.
       val expectedAnswer = spark.range(0, 500).selectExpr("id % 500", "id as value")
         .union(spark.range(500, 1000).selectExpr("id % 500", "id as value"))
-      checkAnswer(
+      QueryTest.checkAnswer(
         join,
         expectedAnswer.collect())
 
@@ -534,7 +527,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       //   ReusedQueryStage 0
      //   ReusedQueryStage 0
       val resultDf = df.join(df, "key").join(df, "key")
-      checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
+      QueryTest.checkAnswer(resultDf, Row(0, 0, 0, 0) :: Nil)
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
       assert(finalPlan.collect { case p: ReusedQueryStageExec => p }.length == 2)
@@ -550,7 +543,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       val grouped = df.groupBy("key").agg(max("value").as("value"))
       val resultDf2 = grouped.groupBy(col("key") + 1).max("value")
         .union(grouped.groupBy(col("key") + 2).max("value"))
-      checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
+      QueryTest.checkAnswer(resultDf2, Row(1, 0) :: Row(2, 0) :: Nil)
 
       val finalPlan2 = resultDf2.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -580,7 +573,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
       val ds = spark.range(3)
       val resultDf = ds.repartition(2, ds.col("id")).toDF()
 
-      checkAnswer(resultDf,
+      QueryTest.checkAnswer(resultDf,
         Seq(0, 1, 2).map(i => Row(i)))
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -596,7 +589,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
       val resultDf = df1.union(df2)
 
-      checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
+      QueryTest.checkAnswer(resultDf, Seq((0), (1), (2), (3)).map(i => Row(i)))
 
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
@@ -352,15 +352,15 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession {
           .select(CONTENT)
       }
       val expected = Seq(Row(content))
-      QueryTest.checkAnswer(readContent(), expected)
+      checkAnswer(readContent(), expected)
       withSQLConf(SOURCES_BINARY_FILE_MAX_LENGTH.key -> content.length.toString) {
-        QueryTest.checkAnswer(readContent(), expected)
+        checkAnswer(readContent(), expected)
       }
       // Disable read. If the implementation attempts to read, the exception would be different.
       file.setReadable(false)
       val caught = intercept[SparkException] {
         withSQLConf(SOURCES_BINARY_FILE_MAX_LENGTH.key -> (content.length - 1).toString) {
-          QueryTest.checkAnswer(readContent(), expected)
+          checkAnswer(readContent(), expected)
         }
       }
       assert(caught.getMessage.contains("exceeds the max length allowed"))
@@ -755,9 +755,9 @@ class StreamSuite extends StreamTest {
       inputData.addData(9)
       streamingQuery.processAllAvailable()
 
-      QueryTest.checkAnswer(spark.table("counts").toDF(),
-        Row("1", 1) :: Row("2", 1) :: Row("3", 2) :: Row("4", 2) ::
-        Row("5", 2) :: Row("6", 2) :: Row("7", 1) :: Row("8", 1) :: Row("9", 1) :: Nil)
+      checkAnswer(spark.table("counts").toDF(),
+        Row(1, 1L) :: Row(2, 1L) :: Row(3, 2L) :: Row(4, 2L) ::
+        Row(5, 2L) :: Row(6, 2L) :: Row(7, 1L) :: Row(8, 1L) :: Row(9, 1L) :: Nil)
     } finally {
       if (streamingQuery ne null) {
        streamingQuery.stop()
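A note on why the expected rows change here and not just the call: before this patch, this call hit the Option[String]-returning overload and discarded the result, so the stale expectations (String keys and Int counts, where a streaming groupBy(...).count() actually yields Long counts) presumably never failed the test. A toy reproduction of that silent-pass mode, using a hypothetical helper rather than the real API:

import org.apache.spark.sql.Row

// Hypothetical stand-in for the old Option[String]-returning check.
def errorMessageIfMismatch(actual: Seq[Row], expected: Seq[Row]): Option[String] =
  if (actual == expected) None else Some(s"expected $expected, got $actual")

val fromCount = Seq(Row(1, 1L))                     // Int key, Long count
errorMessageIfMismatch(fromCount, Seq(Row("1", 1))) // Some(...), but dropped, so nothing failed
// The Unit-returning checkAnswer turns the same mismatch into a test failure.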
@@ -39,10 +39,7 @@ public class JavaDataFrameSuite {
   Dataset<Row> df;
 
   private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
+    QueryTest$.MODULE$.checkAnswer(actual, expected);
   }
 
   @Before
@@ -51,13 +51,6 @@ public class JavaMetastoreDataSourcesSuite {
   FileSystem fs;
   Dataset<Row> df;
 
-  private static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
-    String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
-    if (errorMessage != null) {
-      Assert.fail(errorMessage);
-    }
-  }
-
   @Before
   public void setUp() throws IOException {
     sqlContext = TestHive$.MODULE$;
@@ -100,7 +93,7 @@ public void saveTableAndQueryIt() {
       .options(options)
       .saveAsTable("javaSavedTable");
 
-    checkAnswer(
+    QueryTest$.MODULE$.checkAnswer(
       sqlContext.sql("SELECT * FROM javaSavedTable"),
       df.collectAsList());
   }
@@ -1063,7 +1063,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu
         // todo: remove it?
         val newActual = Dataset.ofRows(spark, actual.logicalPlan)
 
-        QueryTest.checkAnswer(newActual, expectedAnswer) match {
+        QueryTest.getErrorMessageInCheckAnswer(newActual, expectedAnswer) match {
           case Some(errorMessage) =>
             val newErrorMessage =
               s"""