Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit e243ca6

Browse files
committed
Add aggregation iterators.
1 parent a101960 commit e243ca6

File tree

4 files changed

+407
-298
lines changed

4 files changed

+407
-298
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/Aggregate2Sort.scala

Lines changed: 28 additions & 294 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.expressions.aggregate2._
2525
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
2626
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
27-
import org.apache.spark.sql.types.NullType
28-
29-
import scala.collection.mutable.ArrayBuffer
3027

3128
case class Aggregate2Sort(
3229
groupingExpressions: Seq[NamedExpression],
@@ -48,6 +45,16 @@ case class Aggregate2Sort(
4845
}
4946
}
5047

48+
override def references: AttributeSet = {
49+
val referencesInResults =
50+
AttributeSet(resultExpressions.flatMap(_.references)) -- AttributeSet(aggregateAttributes)
51+
52+
AttributeSet(
53+
groupingExpressions.flatMap(_.references) ++
54+
aggregateExpressions.flatMap(_.references) ++
55+
referencesInResults)
56+
}
57+
5158
override def requiredChildDistribution: List[Distribution] = {
5259
if (partialAggregation) {
5360
UnspecifiedDistribution :: Nil
@@ -67,299 +74,26 @@ case class Aggregate2Sort(
6774

6875
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
6976
child.execute().mapPartitions { iter =>
70-
71-
new Iterator[InternalRow] {
72-
// aggregateFunctions contains all of aggregate functions used by this operator.
73-
// When populating aggregateFunctions, we also set bufferOffsets for those
74-
// functions and bind references for non-algebraic aggregate functions when
75-
// the mode is Partial or Complete.
76-
private val aggregateFunctions: Array[AggregateFunction2] = {
77-
var bufferOffset =
78-
if (partialAggregation) {
79-
0
80-
} else {
81-
groupingExpressions.length
82-
}
83-
val functions = new Array[AggregateFunction2](aggregateExpressions.length)
84-
var i = 0
85-
while (i < aggregateExpressions.length) {
86-
val func = aggregateExpressions(i).aggregateFunction
87-
val funcWithBoundReferences = aggregateExpressions(i).mode match {
88-
case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] =>
89-
// We need to create BoundReferences if the function is not an
90-
// AlgebraicAggregate (it does not support code-gen) and the mode of
91-
// this function is Partial or Complete because we will call eval of this
92-
// function's children in the update method of this aggregate function.
93-
// Those eval calls require BoundReferences to work.
94-
BindReferences.bindReference(func, child.output)
95-
case _ => func
96-
}
97-
// Set bufferOffset for this function. It is important that setting bufferOffset
98-
// happens after all potential bindReference operations because bindReference
99-
// will create a new instance of the function.
100-
funcWithBoundReferences.bufferOffset = bufferOffset
101-
bufferOffset += funcWithBoundReferences.bufferSchema.length
102-
functions(i) = funcWithBoundReferences
103-
i += 1
104-
}
105-
functions
106-
}
107-
108-
// All non-algebraic aggregate functions.
109-
private val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
110-
aggregateFunctions.collect {
111-
case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
112-
}.toArray
113-
}
114-
115-
// Positions of those non-algebraic aggregate functions in aggregateFunctions.
116-
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
117-
// func2 and func3 are non-algebraic aggregate functions.
118-
// nonAlgebraicAggregateFunctionPositions will be [1, 2].
119-
private val nonAlgebraicAggregateFunctionPositions: Array[Int] = {
120-
val positions = new ArrayBuffer[Int]()
121-
var i = 0
122-
while (i < aggregateFunctions.length) {
123-
aggregateFunctions(i) match {
124-
case agg: AlgebraicAggregate =>
125-
case _ => positions += i
126-
}
127-
i += 1
128-
}
129-
positions.toArray
130-
}
131-
132-
// The number of elements of the underlying buffer of this operator.
133-
// All aggregate functions are sharing this underlying buffer and they find their
134-
// buffer values through bufferOffset.
135-
private val bufferSize: Int = {
136-
var size = 0
137-
var i = 0
138-
while (i < aggregateFunctions.length) {
139-
size += aggregateFunctions(i).bufferSchema.length
140-
i += 1
141-
}
142-
if (partialAggregation) {
143-
size
144-
} else {
145-
groupingExpressions.length + size
146-
}
147-
}
148-
149-
// This is used to project expressions for the grouping expressions.
150-
protected val groupGenerator =
151-
newMutableProjection(groupingExpressions, child.output)()
152-
// The partition key of the current partition.
153-
private var currentGroupingKey: InternalRow = _
154-
// The partition key of next partition.
155-
private var nextGroupingKey: InternalRow = _
156-
// The first row of next partition.
157-
private var firstRowInNextGroup: InternalRow = _
158-
// Indicates if we has new group of rows to process.
159-
private var hasNewGroup: Boolean = true
160-
// The underlying buffer shared by all aggregate functions.
161-
private val buffer: MutableRow = new GenericMutableRow(bufferSize)
162-
// The result of aggregate functions. It is only used when aggregate functions' modes
163-
// are Final.
164-
private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length)
165-
private val joinedRow = new JoinedRow4
166-
// The projection used to generate the output rows of this operator.
167-
// This is only used when we are generating final results of aggregate functions.
168-
private lazy val resultProjection =
169-
newMutableProjection(
170-
resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
171-
172-
// When we merge buffers (for mode PartialMerge or Final), the input rows start with
173-
// values for grouping expressions. So, when we construct our buffer for this
174-
// aggregate function, the size of the buffer matches the number of values in the
175-
// input rows. To simplify the code for code-gen, we need create some dummy
176-
// attributes and expressions for these grouping expressions.
177-
private val offsetAttributes = {
178-
if (partialAggregation) {
179-
Nil
180-
} else {
181-
Seq.fill(groupingExpressions.length)(AttributeReference("offset", NullType)())
182-
}
183-
}
184-
private val offsetExpressions =
185-
if (partialAggregation) Nil else Seq.fill(groupingExpressions.length)(NoOp)
186-
187-
// This projection is used to initialize buffer values for all AlgebraicAggregates.
188-
private val algebraicInitialProjection = {
189-
val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
190-
case ae: AlgebraicAggregate => ae.initialValues
191-
case agg: AggregateFunction2 => NoOp :: Nil
192-
}
193-
newMutableProjection(initExpressions, Nil)().target(buffer)
194-
}
195-
196-
// This projection is used to update buffer values for all AlgebraicAggregates.
197-
private lazy val algebraicUpdateProjection = {
198-
val bufferSchema = aggregateFunctions.flatMap {
199-
case ae: AlgebraicAggregate => ae.bufferAttributes
200-
case agg: AggregateFunction2 => agg.bufferAttributes
201-
}
202-
val updateExpressions = aggregateFunctions.flatMap {
203-
case ae: AlgebraicAggregate => ae.updateExpressions
204-
case agg: AggregateFunction2 => NoOp :: Nil
205-
}
206-
newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
207-
}
208-
209-
// This projection is used to merge buffer values for all AlgebraicAggregates.
210-
private lazy val algebraicMergeProjection = {
211-
val bufferSchemata =
212-
offsetAttributes ++ aggregateFunctions.flatMap {
213-
case ae: AlgebraicAggregate => ae.bufferAttributes
214-
case agg: AggregateFunction2 => agg.bufferAttributes
215-
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
216-
case ae: AlgebraicAggregate => ae.cloneBufferAttributes
217-
case agg: AggregateFunction2 => agg.cloneBufferAttributes
218-
}
219-
val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
220-
case ae: AlgebraicAggregate => ae.mergeExpressions
221-
case agg: AggregateFunction2 => NoOp :: Nil
222-
}
223-
224-
newMutableProjection(mergeExpressions, bufferSchemata)()
225-
}
226-
227-
// This projection is used to evaluate all AlgebraicAggregates.
228-
private lazy val algebraicEvalProjection = {
229-
val bufferSchemata =
230-
offsetAttributes ++ aggregateFunctions.flatMap {
231-
case ae: AlgebraicAggregate => ae.bufferAttributes
232-
case agg: AggregateFunction2 => agg.bufferAttributes
233-
} ++ offsetAttributes ++ aggregateFunctions.flatMap {
234-
case ae: AlgebraicAggregate => ae.cloneBufferAttributes
235-
case agg: AggregateFunction2 => agg.cloneBufferAttributes
236-
}
237-
val evalExpressions = aggregateFunctions.map {
238-
case ae: AlgebraicAggregate => ae.evaluateExpression
239-
case agg: AggregateFunction2 => NoOp
240-
}
241-
242-
newMutableProjection(evalExpressions, bufferSchemata)()
243-
}
244-
245-
// Initialize this iterator.
246-
initialize()
247-
248-
private def initialize(): Unit = {
249-
if (iter.hasNext) {
250-
initializeBuffer()
251-
val currentRow = iter.next().copy()
252-
// partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
253-
// we are making a copy at here.
254-
nextGroupingKey = groupGenerator(currentRow).copy()
255-
firstRowInNextGroup = currentRow
256-
} else {
257-
// This iter is an empty one.
258-
hasNewGroup = false
259-
}
260-
}
261-
262-
/** Initializes buffer values for all aggregate functions. */
263-
private def initializeBuffer(): Unit = {
264-
algebraicInitialProjection(EmptyRow)
265-
var i = 0
266-
while (i < nonAlgebraicAggregateFunctions.length) {
267-
nonAlgebraicAggregateFunctions(i).initialize(buffer)
268-
i += 1
269-
}
270-
}
271-
272-
/** Processes the current input row. */
273-
private def processRow(row: InternalRow): Unit = {
274-
// The new row is still in the current group.
275-
if (partialAggregation) {
276-
algebraicUpdateProjection(joinedRow(buffer, row))
277-
var i = 0
278-
while (i < nonAlgebraicAggregateFunctions.length) {
279-
nonAlgebraicAggregateFunctions(i).update(buffer, row)
280-
i += 1
281-
}
282-
} else {
283-
algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
284-
var i = 0
285-
while (i < nonAlgebraicAggregateFunctions.length) {
286-
nonAlgebraicAggregateFunctions(i).merge(buffer, row)
287-
i += 1
288-
}
289-
}
290-
}
291-
292-
/** Processes rows in the current group. It will stop when it find a new group. */
293-
private def processCurrentGroup(): Unit = {
294-
currentGroupingKey = nextGroupingKey
295-
// Now, we will start to find all rows belonging to this group.
296-
// We create a variable to track if we see the next group.
297-
var findNextPartition = false
298-
// firstRowInNextGroup is the first row of this group. We first process it.
299-
processRow(firstRowInNextGroup)
300-
// The search will stop when we see the next group or there is no
301-
// input row left in the iter.
302-
while (iter.hasNext && !findNextPartition) {
303-
val currentRow = iter.next()
304-
// Get the grouping key based on the grouping expressions.
305-
// For the below compare method, we do not need to make a copy of groupingKey.
306-
val groupingKey = groupGenerator(currentRow)
307-
// Check if the current row belongs the current input row.
308-
currentGroupingKey.equals(groupingKey)
309-
310-
if (currentGroupingKey == groupingKey) {
311-
processRow(currentRow)
312-
} else {
313-
// We find a new group.
314-
findNextPartition = true
315-
nextGroupingKey = groupingKey.copy()
316-
firstRowInNextGroup = currentRow.copy()
317-
}
318-
}
319-
// We have not seen a new group. It means that there is no new row in the input
320-
// iter. The current group is the last group of the iter.
321-
if (!findNextPartition) {
322-
hasNewGroup = false
323-
}
324-
}
325-
326-
private def generateOutput: () => InternalRow = {
327-
if (partialAggregation) {
328-
// If it is partialAggregation, we just output the grouping columns and the buffer.
329-
() => joinedRow(currentGroupingKey, buffer).copy()
330-
} else {
331-
() => {
332-
algebraicEvalProjection.target(aggregateResult)(buffer)
333-
var i = 0
334-
while (i < nonAlgebraicAggregateFunctions.length) {
335-
aggregateResult.update(
336-
nonAlgebraicAggregateFunctionPositions(i),
337-
nonAlgebraicAggregateFunctions(i).eval(buffer))
338-
i += 1
339-
}
340-
resultProjection(joinedRow(currentGroupingKey, aggregateResult))
341-
}
342-
}
77+
val aggregationIterator =
78+
if (partialAggregation) {
79+
new PartialSortAggregationIterator(
80+
groupingExpressions,
81+
aggregateExpressions,
82+
newMutableProjection,
83+
child.output,
84+
iter)
85+
} else {
86+
new FinalSortAggregationIterator(
87+
groupingExpressions,
88+
aggregateExpressions,
89+
aggregateAttributes,
90+
resultExpressions,
91+
newMutableProjection,
92+
child.output,
93+
iter)
34394
}
34495

345-
override final def hasNext: Boolean = hasNewGroup
346-
347-
override final def next(): InternalRow = {
348-
if (hasNext) {
349-
// Process the current group.
350-
processCurrentGroup()
351-
// Generate output row for the current group.
352-
val outputRow = generateOutput()
353-
// Initilize buffer values for the next group.
354-
initializeBuffer()
355-
356-
outputRow
357-
} else {
358-
// no more result
359-
throw new NoSuchElementException
360-
}
361-
}
362-
}
96+
aggregationIterator
36397
}
36498
}
36599
}

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate2/rules.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ case class CheckAggregateFunction(context: SQLContext) extends (LogicalPlan => U
3939
def apply(plan: LogicalPlan): Unit = plan.foreachUp {
4040
case p if context.conf.useSqlAggregate2 => p.transformExpressionsUp {
4141
case agg: AggregateExpression1 =>
42-
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is enabled. Please disable it to use $agg.")
42+
failAnalysis(
43+
s"${SQLConf.USE_SQL_AGGREGATE2.key} is enabled. Please disable it to use $agg.")
4344
}
4445
case p if !context.conf.useSqlAggregate2 => p.transformExpressionsUp {
4546
case agg: AggregateExpression2 =>
46-
failAnalysis(s"${SQLConf.USE_SQL_AGGREGATE2} is disabled. Please enable it to use $agg.")
47+
failAnalysis(
48+
s"${SQLConf.USE_SQL_AGGREGATE2.key} is disabled. Please enable it to use $agg.")
4749
}
4850
}
4951
}

0 commit comments

Comments
 (0)