diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 69ea62ef5eb7..affe97120c8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql
 
+import scala.util.Random
+
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -558,6 +562,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
   }
 
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <-
+      Seq((true, true), (true, false), (false, true), (false, false))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+
+        // test case for HashAggregate
+        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
+        val hashAggPlan = hashAggDF.queryExecution.executedPlan
+        if (wholeStage) {
+          assert(hashAggPlan.find {
+            case WholeStageCodegenExec(_: HashAggregateExec) => true
+            case _ => false
+          }.isDefined)
+        } else {
+          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
+        }
+        hashAggDF.collect()
+
+        // test case for ObjectHashAggregate and SortAggregate
+        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
+        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
+        if (useObjectHashAgg) {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+        } else {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
+        }
+        objHashAggOrSortAggDF.collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions)
+  }
+
   test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") {
     checkAnswer(
       testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index fdb9f1d1e0e9..0681b9cbeb1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -24,8 +24,6 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -451,49 +449,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
 
-  private def assertNoExceptions(c: Column): Unit = {
-    for ((wholeStage, useObjectHashAgg) <-
-      Seq((true, true), (true, false), (false, true), (false, false))) {
-      withSQLConf(
-        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
-        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
-
-        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
-
-        // HashAggregate test case
-        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
-        val hashAggPlan = hashAggDF.queryExecution.executedPlan
-        if (wholeStage) {
-          assert(hashAggPlan.find {
-            case WholeStageCodegenExec(_: HashAggregateExec) => true
-            case _ => false
-          }.isDefined)
-        } else {
-          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
-        }
-        hashAggDF.collect()
-
-        // ObjectHashAggregate and SortAggregate test case
-        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
-        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
-        if (useObjectHashAgg) {
-          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
-        } else {
-          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
-        }
-        objHashAggOrSortAggDF.collect()
-      }
-    }
-  }
-
-  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
-    " before using it") {
-    Seq(
-      monotonically_increasing_id(), spark_partition_id(),
-      rand(Random.nextLong()), randn(Random.nextLong())
-    ).foreach(assertNoExceptions)
-  }
-
   test("SPARK-21281 use string types by default if array and map have no argument") {
     val ds = spark.range(1)
     var expectedSchema = new StructType()
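
For context, the moved test exercises the same query under every combination of whole-stage codegen and object hash aggregation. Below is a minimal spark-shell sketch, not part of the patch, reproducing one of those paths by hand; it assumes a local `spark` session and that the config keys match what `SQLConf.WHOLESTAGE_CODEGEN_ENABLED` and `SQLConf.USE_OBJECT_HASH_AGG` resolve to on this branch.

// Run in spark-shell; `spark` and the implicits are provided there.
import spark.implicits._
import org.apache.spark.sql.functions.{collect_list, rand}

// Disable both flags to force the SortAggregateExec path.
spark.conf.set("spark.sql.codegen.wholeStage", "false")
spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", "false")

val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// collect_list is not supported by HashAggregateExec, so the planner falls back
// to ObjectHashAggregateExec or SortAggregateExec depending on the flag above.
val agg = df.groupBy("x").agg(rand(42L), collect_list("y"))
println(agg.queryExecution.executedPlan) // expect a SortAggregate node here

// The SPARK-19471 regression itself: collecting must not throw, even though the
// result projection for the nondeterministic rand() is generated lazily.
agg.collect()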