Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.sql

import scala.util.Random

import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -558,6 +562,49 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
}

private def assertNoExceptions(c: Column): Unit = {
for ((wholeStage, useObjectHashAgg) <-
Seq((true, true), (true, false), (false, true), (false, false))) {
withSQLConf(
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
(SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {

val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// test case for HashAggregate
val hashAggDF = df.groupBy("x").agg(c, sum("y"))
val hashAggPlan = hashAggDF.queryExecution.executedPlan
if (wholeStage) {
assert(hashAggPlan.find {
case WholeStageCodegenExec(_: HashAggregateExec) => true
case _ => false
}.isDefined)
} else {
assert(hashAggPlan.isInstanceOf[HashAggregateExec])
}
hashAggDF.collect()

// test case for ObjectHashAggregate and SortAggregate
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
objHashAggOrSortAggDF.collect()
}
}
}

test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it") {
Seq(
monotonically_increasing_id(), spark_partition_id(),
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertNoExceptions)
}

test("SPARK-21580 ints in aggregation expressions are taken as group-by ordinal.") {
checkAnswer(
testData2.groupBy(lit(3), lit(4)).agg(lit(6), lit(7), sum("b")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
Expand Down Expand Up @@ -451,49 +449,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
}

private def assertNoExceptions(c: Column): Unit = {
for ((wholeStage, useObjectHashAgg) <-
Seq((true, true), (true, false), (false, true), (false, false))) {
withSQLConf(
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
(SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {

val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

// HashAggregate test case
val hashAggDF = df.groupBy("x").agg(c, sum("y"))
val hashAggPlan = hashAggDF.queryExecution.executedPlan
if (wholeStage) {
assert(hashAggPlan.find {
case WholeStageCodegenExec(_: HashAggregateExec) => true
case _ => false
}.isDefined)
} else {
assert(hashAggPlan.isInstanceOf[HashAggregateExec])
}
hashAggDF.collect()

// ObjectHashAggregate and SortAggregate test case
val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
if (useObjectHashAgg) {
assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
} else {
assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
}
objHashAggOrSortAggDF.collect()
}
}
}

test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
" before using it") {
Seq(
monotonically_increasing_id(), spark_partition_id(),
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertNoExceptions)
}

test("SPARK-21281 use string types by default if array and map have no argument") {
val ds = spark.range(1)
var expectedSchema = new StructType()
Expand Down