diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 7c11fdb9792e..98c4a5129995 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
  * is used to generate result.
  */
 abstract class AggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     inputAttributes: Seq[Attribute],
     aggregateExpressions: Seq[AggregateExpression],
@@ -217,6 +218,7 @@ abstract class AggregationIterator(
 
       val resultProjection = UnsafeProjection.create(
         resultExpressions, groupingAttributes ++ aggregateAttributes)
+      resultProjection.initialize(partIndex)
 
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         // Generate results for all expression-based aggregate functions.
@@ -235,6 +237,7 @@ abstract class AggregationIterator(
       val resultProjection = UnsafeProjection.create(
         groupingAttributes ++ bufferAttributes,
         groupingAttributes ++ bufferAttributes)
+      resultProjection.initialize(partIndex)
 
       // TypedImperativeAggregate stores generic object in aggregation buffer, and requires
       // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
@@ -256,6 +259,7 @@ abstract class AggregationIterator(
     } else {
       // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
+      resultProjection.initialize(partIndex)
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         resultProjection(currentGroupingKey)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 56f61c30c4a3..80ea45868786 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -96,7 +96,7 @@ case class HashAggregateExec(
     val spillSize = longMetric("spillSize")
     val avgHashProbe = longMetric("avgHashProbe")
 
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
 
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
@@ -106,6 +106,7 @@ case class HashAggregateExec(
       } else {
         val aggregationIterator =
           new TungstenAggregationIterator(
+            partIndex,
             groupingExpressions,
             aggregateExpressions,
             aggregateAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index eef2c4e843f3..c68dbc73f044 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 class ObjectAggregationIterator(
+    partIndex: Int,
     outputAttributes: Seq[Attribute],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -43,6 +44,7 @@ class ObjectAggregationIterator(
     fallbackCountThreshold: Int,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index b53521b1b6ba..6316e06a8f34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
 
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@ case class ObjectHashAggregateExec(
       } else {
         val aggregationIterator =
           new ObjectAggregationIterator(
+            partIndex,
             child.output,
             groupingExpressions,
             aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index be3198b8e7d8..a43235790834 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -74,7 +74,7 @@ case class SortAggregateExec(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -84,6 +84,7 @@ case class SortAggregateExec(
         Iterator[UnsafeRow]()
       } else {
         val outputIter = new SortBasedAggregationIterator(
+          partIndex,
           groupingExpressions,
           child.output,
           iter,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index a5a444b160c6..492b0f2da77c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
  * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     valueAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index cfa930607360..756eeb642e2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator
  *  - Part 8: A utility function used to generate a result when there is no
  *            input and there is no grouping expression.
  *
+ * @param partIndex
+ *   index of the partition
  * @param groupingExpressions
  *   expressions for grouping keys
  * @param aggregateExpressions
@@ -77,6 +79,7 @@ import org.apache.spark.unsafe.KVIterator
  *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
@@ -91,6 +94,7 @@ class TungstenAggregationIterator(
     spillSize: SQLMetric,
     avgHashProbe: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0681b9cbeb1d..fdb9f1d1e0e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -24,6 +24,8 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -449,6 +451,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
 
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <-
+         Seq((true, true), (true, false), (false, true), (false, false))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+
+        // HashAggregate test case
+        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
+        val hashAggPlan = hashAggDF.queryExecution.executedPlan
+        if (wholeStage) {
+          assert(hashAggPlan.find {
+            case WholeStageCodegenExec(_: HashAggregateExec) => true
+            case _ => false
+          }.isDefined)
+        } else {
+          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
+        }
+        hashAggDF.collect()
+
+        // ObjectHashAggregate and SortAggregate test case
+        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
+        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
+        if (useObjectHashAgg) {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+        } else {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
+        }
+        objHashAggOrSortAggDF.collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions)
+  }
+
   test("SPARK-21281 use string types by default if array and map have no argument") {
     val ds = spark.range(1)
     var expectedSchema = new StructType()
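
For context on why the patch takes this shape: Catalyst's nondeterministic expressions (the ones behind `rand()`, `randn()`, `spark_partition_id()`, and `monotonically_increasing_id()`) require `initialize(partitionIndex)` to be called before their first evaluation, and the aggregation iterators were evaluating freshly built result projections without ever initializing them. Below is a minimal, Spark-free sketch of that contract, assuming a simplified analogue of the `Nondeterministic` trait; `NondeterministicSketch`, `SketchRand`, and the seeding scheme are illustrative stand-ins, not Spark's actual implementation.

```scala
import scala.util.Random

// Simplified analogue of Catalyst's Nondeterministic contract: eval() is only
// legal after initialize(partitionIndex) has run for the current partition.
trait NondeterministicSketch {
  @transient private var initialized = false

  final def initialize(partitionIndex: Int): Unit = {
    initializeInternal(partitionIndex)
    initialized = true
  }

  protected def initializeInternal(partitionIndex: Int): Unit

  final def eval(): Any = {
    // The unpatched iterators hit the equivalent of this check: they built a
    // result projection and evaluated it without initializing it first.
    require(initialized, s"${getClass.getName} must be initialized before eval.")
    evalInternal()
  }

  protected def evalInternal(): Any
}

// A rand()-like expression whose value stream depends on the partition.
class SketchRand(seed: Long) extends NondeterministicSketch {
  private var rng: Random = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    // Hypothetical seeding; Spark derives per-partition seeds differently.
    rng = new Random(seed + partitionIndex)
  }

  override protected def evalInternal(): Any = rng.nextDouble()
}

object SketchDemo extends App {
  val expr = new SketchRand(42L)
  expr.initialize(partitionIndex = 0) // mirrors resultProjection.initialize(partIndex)
  println(expr.eval())                // fine; eval() before initialize() would throw
}
```

This is also why the fix threads `partIndex` through every `AggregationIterator` subclass rather than initializing with a constant: `spark_partition_id()` and per-partition seeding need the real partition index, not merely some initialization call.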