diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 7c11fdb9792e..28d2055aef22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
  * is used to generate result.
  */
 abstract class AggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     inputAttributes: Seq[Attribute],
     aggregateExpressions: Seq[AggregateExpression],
@@ -229,6 +230,7 @@ abstract class AggregationIterator(
             allImperativeAggregateFunctions(i).eval(currentBuffer))
           i += 1
         }
+        resultProjection.initialize(partIndex)
         resultProjection(joinedRow(currentGroupingKey, aggregateResult))
       }
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
@@ -251,12 +253,14 @@ abstract class AggregationIterator(
           typedImperativeAggregates(i).serializeAggregateBufferInPlace(currentBuffer)
           i += 1
         }
+        resultProjection.initialize(partIndex)
         resultProjection(joinedRow(currentGroupingKey, currentBuffer))
       }
     } else {
       // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
+        resultProjection.initialize(partIndex)
         resultProjection(currentGroupingKey)
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 68c8e6ce62cb..81d2ceb0697e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -94,7 +94,7 @@ case class HashAggregateExec(
     val peakMemory = longMetric("peakMemory")
     val spillSize = longMetric("spillSize")
 
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
 
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
@@ -104,6 +104,7 @@
       } else {
         val aggregationIterator =
           new TungstenAggregationIterator(
+            partIndex,
             groupingExpressions,
             aggregateExpressions,
             aggregateAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
index 6e47f9d61119..40fccef69c05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 class ObjectAggregationIterator(
+    partIndex: Int,
     outputAttributes: Seq[Attribute],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -43,6 +44,7 @@ class ObjectAggregationIterator(
     fallbackCountThreshold: Int,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
index b53521b1b6ba..6316e06a8f34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala
@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
 
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@
       } else {
         val aggregationIterator =
           new ObjectAggregationIterator(
+            partIndex,
             child.output,
             groupingExpressions,
             aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
index be3198b8e7d8..a43235790834 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala
@@ -74,7 +74,7 @@ case class SortAggregateExec(
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -84,6 +84,7 @@
         Iterator[UnsafeRow]()
       } else {
         val outputIter = new SortBasedAggregationIterator(
+          partIndex,
           groupingExpressions,
           child.output,
           iter,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index bea2dce1a765..0b255c52448a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
  * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@ class SortBasedAggregationIterator(
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     valueAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 2988161ee5e7..a1f11dd9094b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -77,6 +77,7 @@ import org.apache.spark.unsafe.KVIterator
  *          the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
@@ -90,6 +91,7 @@ class TungstenAggregationIterator(
     peakMemory: SQLMetric,
     spillSize: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0e9a2c6cf7de..88fcd3c50640 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -448,6 +448,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       rand(Random.nextLong()), randn(Random.nextLong())
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
+
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <- Seq((true, false), (false, false), (false, true))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+        // HashAggregate
+        df.groupBy("x").agg(c, sum("y")).collect()
+        // ObjectHashAggregate and SortAggregate
+        df.groupBy("x").agg(c, collect_list("y")).collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions(_))
+  }
 }
 
 object DataFrameFunctionsSuite {
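
Note on the fix: the result projection generated in AggregationIterator can reference nondeterministic expressions such as rand(), spark_partition_id(), or monotonically_increasing_id(). Spark's Nondeterministic contract requires initialize(partitionIndex) to be called before eval(), which is why every code path above now calls resultProjection.initialize(partIndex) before the projection is first applied, and why the exec nodes switch from mapPartitions/mapPartitionsInternal to the WithIndex variants to obtain the partition index at all. Below is a minimal standalone Scala sketch of that contract; it is a simplified stand-in, not the actual Spark trait (whose eval() performs a similar require check):

    import scala.util.Random

    // Simplified stand-in for Spark's Nondeterministic contract: per-partition
    // state must be seeded via initialize(partitionIndex) before eval() is legal.
    trait NondeterministicLike {
      private var initialized = false
      private var rng: Random = _

      def initialize(partitionIndex: Int): Unit = {
        // Seed deterministically from the partition index so that a
        // recomputed partition reproduces the same values.
        rng = new Random(partitionIndex)
        initialized = true
      }

      def eval(): Long = {
        // Failing fast on an uninitialized expression is the exception that
        // the new SPARK-19471 test (assertNoExceptions) verifies is gone.
        require(initialized, "expression should be initialized before eval")
        rng.nextLong()
      }
    }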