From 9434fc767e6b7907b9beacb2f4358d767c9d4d32 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 15 Aug 2017 10:29:25 +0800 Subject: [PATCH 1/4] move test case to DataFrameAggregateSuite --- .../spark/sql/DataFrameAggregateSuite.scala | 47 +++++++++++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 45 ------------------ 2 files changed, 47 insertions(+), 45 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 69ea62ef5eb7..32c2777c407e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -24,6 +26,8 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData import org.apache.spark.sql.types.{Decimal, DecimalType} +import scala.util.Random + case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { @@ -573,4 +577,47 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), Seq(Row(3, 4, 9))) } + + private def assertNoExceptions(c: Column): Unit = { + for ((wholeStage, useObjectHashAgg) <- + Seq((true, true), (true, false), (false, true), (false, false))) { + withSQLConf( + (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), + (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { + + val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") + + // HashAggregate test case + val hashAggDF = df.groupBy("x").agg(c, sum("y")) + val hashAggPlan = hashAggDF.queryExecution.executedPlan + if (wholeStage) { + assert(hashAggPlan.find { + case WholeStageCodegenExec(_: HashAggregateExec) => true + case _ => false + }.isDefined) + } else { + assert(hashAggPlan.isInstanceOf[HashAggregateExec]) + } + hashAggDF.collect() + + // ObjectHashAggregate and SortAggregate test case + val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) + val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan + if (useObjectHashAgg) { + assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) + } else { + assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) + } + objHashAggOrSortAggDF.collect() + } + } + } + + test("SPARK-19471: AggregationIterator does not initialize the generated result projection" + + " before using it") { + Seq( + monotonically_increasing_id(), spark_partition_id(), + rand(Random.nextLong()), randn(Random.nextLong()) + ).foreach(assertNoExceptions) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fdb9f1d1e0e9..0681b9cbeb1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -24,8 +24,6 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.execution.WholeStageCodegenExec -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -451,49 +449,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_)) } - private def assertNoExceptions(c: Column): Unit = { - for ((wholeStage, useObjectHashAgg) <- - Seq((true, true), (true, false), (false, true), (false, false))) { - withSQLConf( - (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString), - (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) { - - val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") - - // HashAggregate test case - val hashAggDF = df.groupBy("x").agg(c, sum("y")) - val hashAggPlan = hashAggDF.queryExecution.executedPlan - if (wholeStage) { - assert(hashAggPlan.find { - case WholeStageCodegenExec(_: HashAggregateExec) => true - case _ => false - }.isDefined) - } else { - assert(hashAggPlan.isInstanceOf[HashAggregateExec]) - } - hashAggDF.collect() - - // ObjectHashAggregate and SortAggregate test case - val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) - val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan - if (useObjectHashAgg) { - assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec]) - } else { - assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec]) - } - objHashAggOrSortAggDF.collect() - } - } - } - - test("SPARK-19471: AggregationIterator does not initialize the generated result projection" + - " before using it") { - Seq( - monotonically_increasing_id(), spark_partition_id(), - rand(Random.nextLong()), randn(Random.nextLong()) - ).foreach(assertNoExceptions) - } - test("SPARK-21281 use string types by default if array and map have no argument") { val ds = spark.range(1) var expectedSchema = new StructType() From f057ff8400076fce615fe9b6521ed1b3d66cb669 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 15 Aug 2017 10:37:33 +0800 Subject: [PATCH 2/4] change import order --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 32c2777c407e..e61581ae63c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import scala.util.Random + import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.expressions.Window @@ -26,7 +28,6 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData import org.apache.spark.sql.types.{Decimal, DecimalType} -import scala.util.Random case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) From 9f2ec8f2c2465d12d06c944b814f5363a99a0271 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 15 Aug 2017 14:18:00 +0800 Subject: [PATCH 3/4] fix code style --- .../spark/sql/DataFrameAggregateSuite.scala | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e61581ae63c3..aba12bd41076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData.DecimalData import org.apache.spark.sql.types.{Decimal, DecimalType} - case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double) class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { @@ -563,22 +562,6 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) } - test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") { - checkAnswer( - testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")), - Seq(Row(3, 4, 6, 7, 9))) - checkAnswer( - testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")), - Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) - - checkAnswer( - spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"), - Seq(Row(3, 4, 9))) - checkAnswer( - spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), - Seq(Row(3, 4, 9))) - } - private def assertNoExceptions(c: Column): Unit = { for ((wholeStage, useObjectHashAgg) <- Seq((true, true), (true, false), (false, true), (false, false))) { @@ -621,4 +604,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { rand(Random.nextLong()), randn(Random.nextLong()) ).foreach(assertNoExceptions) } + + test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") { + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")), + Seq(Row(3, 4, 6, 7, 9))) + checkAnswer( + testData2.groupBy(lit(3), lit(4)).agg(lit(6), 'b, sum("b")), + Seq(Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6))) + + checkAnswer( + spark.sql("SELECT 3, 4, SUM(b) FROM testData2 GROUP BY 1, 2"), + Seq(Row(3, 4, 9))) + checkAnswer( + spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), + Seq(Row(3, 4, 9))) + } } From ab61839f90c95fabce47907e0ce649db301c56c9 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Tue, 15 Aug 2017 17:35:09 +0800 Subject: [PATCH 4/4] rewrite comment --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index aba12bd41076..affe97120c8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -571,7 +571,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y") - // HashAggregate test case + // test case for HashAggregate val hashAggDF = df.groupBy("x").agg(c, sum("y")) val hashAggPlan = hashAggDF.queryExecution.executedPlan if (wholeStage) { @@ -584,7 +584,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } hashAggDF.collect() - // ObjectHashAggregate and SortAggregate test case + // test case for ObjectHashAggregate and SortAggregate val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y")) val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan if (useObjectHashAgg) {