[SPARK-19471] AggregationIterator does not initialize the generated result projection before using it

Status: Closed.

Summary: every physical aggregate operator now receives the partition index from
mapPartitionsWithIndex / mapPartitionsWithIndexInternal and forwards it to its
aggregation iterator, which calls resultProjection.initialize(partIndex) on each
generated result projection. This ensures nondeterministic expressions (for
example rand, randn, spark_partition_id, monotonically_increasing_id) are
initialized before they are evaluated in aggregate results.
AggregationIterator.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
  * is used to generate result.
  */
 abstract class AggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     inputAttributes: Seq[Attribute],
     aggregateExpressions: Seq[AggregateExpression],
@@ -217,6 +218,7 @@ abstract class AggregationIterator(
 
       val resultProjection =
         UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateAttributes)
+      resultProjection.initialize(partIndex)
 
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         // Generate results for all expression-based aggregate functions.
@@ -235,6 +237,7 @@
       val resultProjection = UnsafeProjection.create(
         groupingAttributes ++ bufferAttributes,
         groupingAttributes ++ bufferAttributes)
+      resultProjection.initialize(partIndex)
 
       // TypedImperativeAggregate stores generic object in aggregation buffer, and requires
       // calling serialization before shuffling. See [[TypedImperativeAggregate]] for more info.
@@ -256,6 +259,7 @@
     } else {
       // Grouping-only: we only output values based on grouping expressions.
       val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes)
+      resultProjection.initialize(partIndex)
       (currentGroupingKey: UnsafeRow, currentBuffer: InternalRow) => {
         resultProjection(currentGroupingKey)
       }
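Note on the fix itself: in Spark, expressions that mix in the Nondeterministic trait (rand, randn, monotonically_increasing_id, spark_partition_id) carry per-partition state and must be given the partition index via initialize(partitionIndex) before their first eval. The added resultProjection.initialize(partIndex) calls above do exactly that for the generated result projections. Below is a minimal, dependency-free Scala sketch of that contract; the names mirror Spark's Nondeterministic trait but the code is illustrative, not Spark's actual implementation.

```scala
import scala.util.Random

// Sketch of the contract: initialize(partitionIndex) must run before eval().
trait NondeterministicLike {
  private var initialized = false
  protected def initializeInternal(partitionIndex: Int): Unit
  protected def evalInternal(): Double

  final def initialize(partitionIndex: Int): Unit = {
    initializeInternal(partitionIndex)
    initialized = true
  }
  final def eval(): Double = {
    require(initialized, "expression must be initialized before eval")
    evalInternal()
  }
}

// A rand()-style expression: the partition index is mixed into the seed so
// each partition draws an independent random stream.
final class RandLike(seed: Long) extends NondeterministicLike {
  private var rng: Random = _
  protected def initializeInternal(partitionIndex: Int): Unit =
    rng = new Random(seed + partitionIndex)
  protected def evalInternal(): Double = rng.nextDouble()
}

object InitContractDemo extends App {
  val expr = new RandLike(42L)
  expr.initialize(partitionIndex = 0) // without this line, eval() fails
  println(expr.eval())
}
```

Before this PR, the result projections built in AggregationIterator skipped that initialize call, so such expressions could be evaluated uninitialized when they appeared in aggregate results.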
HashAggregateExec.scala
@@ -96,7 +96,7 @@ case class HashAggregateExec(
     val spillSize = longMetric("spillSize")
     val avgHashProbe = longMetric("avgHashProbe")
 
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>
 
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
@@ -106,6 +106,7 @@
       } else {
         val aggregationIterator =
           new TungstenAggregationIterator(
+            partIndex,
             groupingExpressions,
             aggregateExpressions,
             aggregateAttributes,
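The operator-side change is mechanical but worth spelling out: mapPartitions gives the closure only the row iterator, while mapPartitionsWithIndex also hands it the partition's index. A self-contained sketch of the difference, assuming a local Spark install (mapPartitionsWithIndexInternal, used by the other operators below, is a Spark-private variant of the same idea that skips closure cleaning):

```scala
import org.apache.spark.sql.SparkSession

object PartitionIndexDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
  val rdd = spark.sparkContext.parallelize(1 to 8, numSlices = 4)

  // The closure now receives the partition's index alongside its iterator.
  val tagged = rdd.mapPartitionsWithIndex { (partIndex, iter) =>
    iter.map(x => s"partition $partIndex -> $x")
  }
  tagged.collect().foreach(println)
  spark.stop()
}
```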
ObjectAggregationIterator.scala
@@ -31,6 +31,7 @@ import org.apache.spark.unsafe.KVIterator
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 
 class ObjectAggregationIterator(
+    partIndex: Int,
     outputAttributes: Seq[Attribute],
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
@@ -43,6 +44,7 @@
     fallbackCountThreshold: Int,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
ObjectHashAggregateExec.scala
@@ -98,7 +98,7 @@ case class ObjectHashAggregateExec(
     val numOutputRows = longMetric("numOutputRows")
     val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold
 
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       val hasInput = iter.hasNext
       if (!hasInput && groupingExpressions.nonEmpty) {
         // This is a grouped aggregate and the input kvIterator is empty,
@@ -107,6 +107,7 @@
       } else {
         val aggregationIterator =
           new ObjectAggregationIterator(
+            partIndex,
             child.output,
             groupingExpressions,
             aggregateExpressions,
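For context: ObjectHashAggregateExec is only planned when the object-hash path is enabled; otherwise TypedImperativeAggregate functions such as collect_list fall back to SortAggregateExec (next file). A hedged sketch of observing that planner choice — the config key matches SQLConf.USE_OBJECT_HASH_AGG as of Spark 2.x, and later versions may wrap the plan (for example with adaptive execution):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.collect_list

object OperatorChoiceDemo extends App {
  val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
  import spark.implicits._
  val df = Seq(("1", 1), ("2", 2)).toDF("x", "y")

  // collect_list has no UnsafeRow-backed buffer, so HashAggregateExec cannot
  // run it; the planner picks ObjectHashAggregateExec when the flag is on,
  // SortAggregateExec otherwise.
  for (flag <- Seq("true", "false")) {
    spark.conf.set("spark.sql.execution.useObjectHashAggregateExec", flag)
    val plan = df.groupBy("x").agg(collect_list("y")).queryExecution.executedPlan
    println(s"useObjectHashAgg=$flag -> ${plan.getClass.getSimpleName}")
  }
  spark.stop()
}
```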
SortAggregateExec.scala
@@ -74,7 +74,7 @@ case class SortAggregateExec(
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext
@@ -84,6 +84,7 @@
         Iterator[UnsafeRow]()
       } else {
         val outputIter = new SortBasedAggregationIterator(
+          partIndex,
           groupingExpressions,
           child.output,
           iter,
SortBasedAggregationIterator.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
  * sorted by values of [[groupingExpressions]].
  */
 class SortBasedAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     valueAttributes: Seq[Attribute],
     inputIterator: Iterator[InternalRow],
@@ -37,6 +38,7 @@
     newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
     numOutputRows: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     valueAttributes,
     aggregateExpressions,
TungstenAggregationIterator.scala
@@ -60,6 +60,8 @@ import org.apache.spark.unsafe.KVIterator
  * - Part 8: A utility function used to generate a result when there is no
  *           input and there is no grouping expression.
  *
+ * @param partIndex
+ *   index of the partition
  * @param groupingExpressions
  *   expressions for grouping keys
  * @param aggregateExpressions
@@ -77,6 +79,7 @@
  *   the iterator containing input [[UnsafeRow]]s.
  */
 class TungstenAggregationIterator(
+    partIndex: Int,
     groupingExpressions: Seq[NamedExpression],
     aggregateExpressions: Seq[AggregateExpression],
     aggregateAttributes: Seq[Attribute],
@@ -91,6 +94,7 @@
     spillSize: SQLMetric,
     avgHashProbe: SQLMetric)
   extends AggregationIterator(
+    partIndex,
     groupingExpressions,
     originalInputAttributes,
     aggregateExpressions,
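Taken together, the change in every operator is one thread of data through three layers: the *WithIndex callback receives the partition index, the iterator constructor forwards it, and the iterator initializes each generated projection with it before producing any row. A condensed, illustrative sketch of that wiring (plain Scala, invented names, not Spark's real signatures):

```scala
// Illustrative wiring only -- not Spark's actual classes.
final class ProjectionSketch {
  private var boundPartition = -1
  def initialize(partitionIndex: Int): Unit = boundPartition = partitionIndex
  def describe: String = s"projection bound to partition $boundPartition"
}

// Mirrors AggregationIterator: accept partIndex and initialize the generated
// result projection once, before producing any row.
class AggregationIteratorSketch(partIndex: Int) {
  private val resultProjection = new ProjectionSketch
  resultProjection.initialize(partIndex)
  def describe: String = resultProjection.describe
}

object WiringDemo extends App {
  // Mirrors doExecute(): the *WithIndex variant supplies partIndex per partition.
  Seq(0, 1, 2).foreach { i =>
    println(new AggregationIteratorSketch(i).describe)
  }
}
```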
DataFrameFunctionsSuite.scala
@@ -24,6 +24,8 @@ import scala.util.Random
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -449,6 +451,49 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
   }
 
+  private def assertNoExceptions(c: Column): Unit = {
+    for ((wholeStage, useObjectHashAgg) <-
+      Seq((true, true), (true, false), (false, true), (false, false))) {
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString),
+        (SQLConf.USE_OBJECT_HASH_AGG.key, useObjectHashAgg.toString)) {
+
+        val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")
+
+        // HashAggregate test case
+        val hashAggDF = df.groupBy("x").agg(c, sum("y"))
+        val hashAggPlan = hashAggDF.queryExecution.executedPlan
+        if (wholeStage) {
+          assert(hashAggPlan.find {
+            case WholeStageCodegenExec(_: HashAggregateExec) => true
+            case _ => false
+          }.isDefined)
+        } else {
+          assert(hashAggPlan.isInstanceOf[HashAggregateExec])
+        }
+        hashAggDF.collect()
+
+        // ObjectHashAggregate and SortAggregate test case
+        val objHashAggOrSortAggDF = df.groupBy("x").agg(c, collect_list("y"))
+        val objHashAggOrSortAggPlan = objHashAggOrSortAggDF.queryExecution.executedPlan
+        if (useObjectHashAgg) {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[ObjectHashAggregateExec])
+        } else {
+          assert(objHashAggOrSortAggPlan.isInstanceOf[SortAggregateExec])
+        }
+        objHashAggOrSortAggDF.collect()
+      }
+    }
+  }
+
+  test("SPARK-19471: AggregationIterator does not initialize the generated result projection" +
+    " before using it") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertNoExceptions)
+  }
+
   test("SPARK-21281 use string types by default if array and map have no argument") {
     val ds = spark.range(1)
     var expectedSchema = new StructType()

Reviewer comment (Member), on assertNoExceptions: Could you submit a follow-up PR to move this test case to DataFrameAggregateSuite? Thanks!
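For reference, a minimal end-to-end sketch of the scenario the new test guards against, assuming a local SparkSession: before this fix, evaluating a nondeterministic column in an aggregate result could fail because the generated result projection was never initialized with the partition index.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object Spark19471Repro extends App {
  val spark = SparkSession.builder().master("local[2]").appName("repro").getOrCreate()
  import spark.implicits._

  val df = Seq(("1", 1), ("1", 2), ("2", 3), ("2", 4)).toDF("x", "y")

  // With the fix, each nondeterministic column evaluates cleanly inside an
  // aggregate result; before it, collect() could fail on these queries.
  df.groupBy("x").agg(spark_partition_id(), sum("y")).collect()
  df.groupBy("x").agg(monotonically_increasing_id(), collect_list("y")).collect()

  spark.stop()
}
```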