@@ -24,9 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
2424import org .apache .spark .sql .catalyst .expressions .aggregate2 ._
2525import org .apache .spark .sql .catalyst .plans .physical .{AllTuples , ClusteredDistribution , Distribution , UnspecifiedDistribution }
2626import org .apache .spark .sql .execution .{SparkPlan , UnaryNode }
27- import org .apache .spark .sql .types .NullType
28-
29- import scala .collection .mutable .ArrayBuffer
3027
3128case class Aggregate2Sort (
3229 groupingExpressions : Seq [NamedExpression ],
@@ -48,6 +45,16 @@ case class Aggregate2Sort(
4845 }
4946 }
5047
48+ override def references : AttributeSet = {
49+ val referencesInResults =
50+ AttributeSet (resultExpressions.flatMap(_.references)) -- AttributeSet (aggregateAttributes)
51+
52+ AttributeSet (
53+ groupingExpressions.flatMap(_.references) ++
54+ aggregateExpressions.flatMap(_.references) ++
55+ referencesInResults)
56+ }
57+
5158 override def requiredChildDistribution : List [Distribution ] = {
5259 if (partialAggregation) {
5360 UnspecifiedDistribution :: Nil
@@ -67,299 +74,26 @@ case class Aggregate2Sort(
6774
6875 protected override def doExecute (): RDD [InternalRow ] = attachTree(this , " execute" ) {
6976 child.execute().mapPartitions { iter =>
70-
71- new Iterator [InternalRow ] {
72- // aggregateFunctions contains all of aggregate functions used by this operator.
73- // When populating aggregateFunctions, we also set bufferOffsets for those
74- // functions and bind references for non-algebraic aggregate functions when
75- // the mode is Partial or Complete.
76- private val aggregateFunctions : Array [AggregateFunction2 ] = {
77- var bufferOffset =
78- if (partialAggregation) {
79- 0
80- } else {
81- groupingExpressions.length
82- }
83- val functions = new Array [AggregateFunction2 ](aggregateExpressions.length)
84- var i = 0
85- while (i < aggregateExpressions.length) {
86- val func = aggregateExpressions(i).aggregateFunction
87- val funcWithBoundReferences = aggregateExpressions(i).mode match {
88- case Partial | Complete if ! func.isInstanceOf [AlgebraicAggregate ] =>
89- // We need to create BoundReferences if the function is not an
90- // AlgebraicAggregate (it does not support code-gen) and the mode of
91- // this function is Partial or Complete because we will call eval of this
92- // function's children in the update method of this aggregate function.
93- // Those eval calls require BoundReferences to work.
94- BindReferences .bindReference(func, child.output)
95- case _ => func
96- }
97- // Set bufferOffset for this function. It is important that setting bufferOffset
98- // happens after all potential bindReference operations because bindReference
99- // will create a new instance of the function.
100- funcWithBoundReferences.bufferOffset = bufferOffset
101- bufferOffset += funcWithBoundReferences.bufferSchema.length
102- functions(i) = funcWithBoundReferences
103- i += 1
104- }
105- functions
106- }
107-
108- // All non-algebraic aggregate functions.
109- private val nonAlgebraicAggregateFunctions : Array [AggregateFunction2 ] = {
110- aggregateFunctions.collect {
111- case func : AggregateFunction2 if ! func.isInstanceOf [AlgebraicAggregate ] => func
112- }.toArray
113- }
114-
115- // Positions of those non-algebraic aggregate functions in aggregateFunctions.
116- // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
117- // func2 and func3 are non-algebraic aggregate functions.
118- // nonAlgebraicAggregateFunctionPositions will be [1, 2].
119- private val nonAlgebraicAggregateFunctionPositions : Array [Int ] = {
120- val positions = new ArrayBuffer [Int ]()
121- var i = 0
122- while (i < aggregateFunctions.length) {
123- aggregateFunctions(i) match {
124- case agg : AlgebraicAggregate =>
125- case _ => positions += i
126- }
127- i += 1
128- }
129- positions.toArray
130- }
131-
132- // The number of elements of the underlying buffer of this operator.
133- // All aggregate functions are sharing this underlying buffer and they find their
134- // buffer values through bufferOffset.
135- private val bufferSize : Int = {
136- var size = 0
137- var i = 0
138- while (i < aggregateFunctions.length) {
139- size += aggregateFunctions(i).bufferSchema.length
140- i += 1
141- }
142- if (partialAggregation) {
143- size
144- } else {
145- groupingExpressions.length + size
146- }
147- }
148-
149- // This is used to project expressions for the grouping expressions.
150- protected val groupGenerator =
151- newMutableProjection(groupingExpressions, child.output)()
152- // The partition key of the current partition.
153- private var currentGroupingKey : InternalRow = _
154- // The partition key of next partition.
155- private var nextGroupingKey : InternalRow = _
156- // The first row of next partition.
157- private var firstRowInNextGroup : InternalRow = _
158- // Indicates if we has new group of rows to process.
159- private var hasNewGroup : Boolean = true
160- // The underlying buffer shared by all aggregate functions.
161- private val buffer : MutableRow = new GenericMutableRow (bufferSize)
162- // The result of aggregate functions. It is only used when aggregate functions' modes
163- // are Final.
164- private val aggregateResult : MutableRow = new GenericMutableRow (aggregateAttributes.length)
165- private val joinedRow = new JoinedRow4
166- // The projection used to generate the output rows of this operator.
167- // This is only used when we are generating final results of aggregate functions.
168- private lazy val resultProjection =
169- newMutableProjection(
170- resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)()
171-
172- // When we merge buffers (for mode PartialMerge or Final), the input rows start with
173- // values for grouping expressions. So, when we construct our buffer for this
174- // aggregate function, the size of the buffer matches the number of values in the
175- // input rows. To simplify the code for code-gen, we need create some dummy
176- // attributes and expressions for these grouping expressions.
177- private val offsetAttributes = {
178- if (partialAggregation) {
179- Nil
180- } else {
181- Seq .fill(groupingExpressions.length)(AttributeReference (" offset" , NullType )())
182- }
183- }
184- private val offsetExpressions =
185- if (partialAggregation) Nil else Seq .fill(groupingExpressions.length)(NoOp )
186-
187- // This projection is used to initialize buffer values for all AlgebraicAggregates.
188- private val algebraicInitialProjection = {
189- val initExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
190- case ae : AlgebraicAggregate => ae.initialValues
191- case agg : AggregateFunction2 => NoOp :: Nil
192- }
193- newMutableProjection(initExpressions, Nil )().target(buffer)
194- }
195-
196- // This projection is used to update buffer values for all AlgebraicAggregates.
197- private lazy val algebraicUpdateProjection = {
198- val bufferSchema = aggregateFunctions.flatMap {
199- case ae : AlgebraicAggregate => ae.bufferAttributes
200- case agg : AggregateFunction2 => agg.bufferAttributes
201- }
202- val updateExpressions = aggregateFunctions.flatMap {
203- case ae : AlgebraicAggregate => ae.updateExpressions
204- case agg : AggregateFunction2 => NoOp :: Nil
205- }
206- newMutableProjection(updateExpressions, bufferSchema ++ child.output)().target(buffer)
207- }
208-
209- // This projection is used to merge buffer values for all AlgebraicAggregates.
210- private lazy val algebraicMergeProjection = {
211- val bufferSchemata =
212- offsetAttributes ++ aggregateFunctions.flatMap {
213- case ae : AlgebraicAggregate => ae.bufferAttributes
214- case agg : AggregateFunction2 => agg.bufferAttributes
215- } ++ offsetAttributes ++ aggregateFunctions.flatMap {
216- case ae : AlgebraicAggregate => ae.cloneBufferAttributes
217- case agg : AggregateFunction2 => agg.cloneBufferAttributes
218- }
219- val mergeExpressions = offsetExpressions ++ aggregateFunctions.flatMap {
220- case ae : AlgebraicAggregate => ae.mergeExpressions
221- case agg : AggregateFunction2 => NoOp :: Nil
222- }
223-
224- newMutableProjection(mergeExpressions, bufferSchemata)()
225- }
226-
227- // This projection is used to evaluate all AlgebraicAggregates.
228- private lazy val algebraicEvalProjection = {
229- val bufferSchemata =
230- offsetAttributes ++ aggregateFunctions.flatMap {
231- case ae : AlgebraicAggregate => ae.bufferAttributes
232- case agg : AggregateFunction2 => agg.bufferAttributes
233- } ++ offsetAttributes ++ aggregateFunctions.flatMap {
234- case ae : AlgebraicAggregate => ae.cloneBufferAttributes
235- case agg : AggregateFunction2 => agg.cloneBufferAttributes
236- }
237- val evalExpressions = aggregateFunctions.map {
238- case ae : AlgebraicAggregate => ae.evaluateExpression
239- case agg : AggregateFunction2 => NoOp
240- }
241-
242- newMutableProjection(evalExpressions, bufferSchemata)()
243- }
244-
245- // Initialize this iterator.
246- initialize()
247-
248- private def initialize (): Unit = {
249- if (iter.hasNext) {
250- initializeBuffer()
251- val currentRow = iter.next().copy()
252- // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey,
253- // we are making a copy at here.
254- nextGroupingKey = groupGenerator(currentRow).copy()
255- firstRowInNextGroup = currentRow
256- } else {
257- // This iter is an empty one.
258- hasNewGroup = false
259- }
260- }
261-
262- /** Initializes buffer values for all aggregate functions. */
263- private def initializeBuffer (): Unit = {
264- algebraicInitialProjection(EmptyRow )
265- var i = 0
266- while (i < nonAlgebraicAggregateFunctions.length) {
267- nonAlgebraicAggregateFunctions(i).initialize(buffer)
268- i += 1
269- }
270- }
271-
272- /** Processes the current input row. */
273- private def processRow (row : InternalRow ): Unit = {
274- // The new row is still in the current group.
275- if (partialAggregation) {
276- algebraicUpdateProjection(joinedRow(buffer, row))
277- var i = 0
278- while (i < nonAlgebraicAggregateFunctions.length) {
279- nonAlgebraicAggregateFunctions(i).update(buffer, row)
280- i += 1
281- }
282- } else {
283- algebraicMergeProjection.target(buffer)(joinedRow(buffer, row))
284- var i = 0
285- while (i < nonAlgebraicAggregateFunctions.length) {
286- nonAlgebraicAggregateFunctions(i).merge(buffer, row)
287- i += 1
288- }
289- }
290- }
291-
292- /** Processes rows in the current group. It will stop when it find a new group. */
293- private def processCurrentGroup (): Unit = {
294- currentGroupingKey = nextGroupingKey
295- // Now, we will start to find all rows belonging to this group.
296- // We create a variable to track if we see the next group.
297- var findNextPartition = false
298- // firstRowInNextGroup is the first row of this group. We first process it.
299- processRow(firstRowInNextGroup)
300- // The search will stop when we see the next group or there is no
301- // input row left in the iter.
302- while (iter.hasNext && ! findNextPartition) {
303- val currentRow = iter.next()
304- // Get the grouping key based on the grouping expressions.
305- // For the below compare method, we do not need to make a copy of groupingKey.
306- val groupingKey = groupGenerator(currentRow)
307- // Check if the current row belongs the current input row.
308- currentGroupingKey.equals(groupingKey)
309-
310- if (currentGroupingKey == groupingKey) {
311- processRow(currentRow)
312- } else {
313- // We find a new group.
314- findNextPartition = true
315- nextGroupingKey = groupingKey.copy()
316- firstRowInNextGroup = currentRow.copy()
317- }
318- }
319- // We have not seen a new group. It means that there is no new row in the input
320- // iter. The current group is the last group of the iter.
321- if (! findNextPartition) {
322- hasNewGroup = false
323- }
324- }
325-
326- private def generateOutput : () => InternalRow = {
327- if (partialAggregation) {
328- // If it is partialAggregation, we just output the grouping columns and the buffer.
329- () => joinedRow(currentGroupingKey, buffer).copy()
330- } else {
331- () => {
332- algebraicEvalProjection.target(aggregateResult)(buffer)
333- var i = 0
334- while (i < nonAlgebraicAggregateFunctions.length) {
335- aggregateResult.update(
336- nonAlgebraicAggregateFunctionPositions(i),
337- nonAlgebraicAggregateFunctions(i).eval(buffer))
338- i += 1
339- }
340- resultProjection(joinedRow(currentGroupingKey, aggregateResult))
341- }
342- }
77+ val aggregationIterator =
78+ if (partialAggregation) {
79+ new PartialSortAggregationIterator (
80+ groupingExpressions,
81+ aggregateExpressions,
82+ newMutableProjection,
83+ child.output,
84+ iter)
85+ } else {
86+ new FinalSortAggregationIterator (
87+ groupingExpressions,
88+ aggregateExpressions,
89+ aggregateAttributes,
90+ resultExpressions,
91+ newMutableProjection,
92+ child.output,
93+ iter)
34394 }
34495
345- override final def hasNext : Boolean = hasNewGroup
346-
347- override final def next (): InternalRow = {
348- if (hasNext) {
349- // Process the current group.
350- processCurrentGroup()
351- // Generate output row for the current group.
352- val outputRow = generateOutput()
353- // Initilize buffer values for the next group.
354- initializeBuffer()
355-
356- outputRow
357- } else {
358- // no more result
359- throw new NoSuchElementException
360- }
361- }
362- }
96+ aggregationIterator
36397 }
36498 }
36599}
0 commit comments