
Commit 338efee

Ngone51 authored and cloud-fan committed
[SPARK-32031][SQL] Fix the wrong references of the PartialMerge/Final AggregateExpression
### What changes were proposed in this pull request?

This PR changes the references of a `PartialMerge`/`Final` `AggregateExpression` from `aggBufferAttributes` to `inputAggBufferAttributes`. After this change, the tests of `SPARK-31620` can fail on the assertion `QueryTest.assertEmptyMissingInput`, so this PR also fixes that by overriding `producedAttributes` on the Aggregate operators to account for their `inputAggBufferAttributes`.

### Why are the changes needed?

With my understanding of the Aggregate framework, and especially of the planning logic in `AggUtils.planAggXXX`, the right references for a `PartialMerge`/`Final` `AggregateExpression` are the `inputAggBufferAttributes`, i.e. the buffer attributes produced by the child operator.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Before this patch, an Aggregate operator's input attributes were always equal to or a superset of its referenced attributes, because it referred to its own buffer attributes instead of the attributes coming from its child. Its missing inputs were therefore always empty and nothing broke, which made a dedicated UT for this patch impossible. After correcting the references, the problem is exposed by `QueryTest.assertEmptyMissingInput` in the UT for SPARK-31620, since missing inputs are no longer guaranteed to be empty; this PR fixes that failure as well.

Closes #28869 from Ngone51/fix-agg-reference.

Authored-by: yi.wu <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
Parent: 6293c38 · Commit: 338efee

5 files changed: +25 −35 lines
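To make the two-phase flow described in the commit message concrete, here is a minimal, self-contained Scala sketch. All names are illustrative stand-ins, not Spark's actual API: the Partial stage reads the raw input column, while the Final stage reads only the buffers the Partial stage emitted — the role played by `inputAggBufferAttributes` in the real planner.

```scala
// A minimal, self-contained model of two-phase aggregation (illustrative only).
object TwoPhaseAggSketch {
  // Partial stage: reads the raw input column of one partition and emits a
  // partial aggregation buffer (here, just a running sum).
  def partialSum(partition: Seq[Int]): Long =
    partition.foldLeft(0L)(_ + _)

  // Final stage: reads only the buffers produced by the partial stage -- the
  // analogue of `inputAggBufferAttributes` -- never the raw input column.
  def finalSum(partialBuffers: Seq[Long]): Long =
    partialBuffers.sum

  def main(args: Array[String]): Unit = {
    val partitions = Seq(Seq(1, 2, 3), Seq(4, 5))
    val buffers = partitions.map(partialSum) // Seq(6L, 9L)
    assert(finalSum(buffers) == 15L)
    println(s"total = ${finalSum(buffers)}")
  }
}
```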

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -140,7 +140,7 @@ case class AggregateExpression(
   override lazy val references: AttributeSet = {
     val aggAttributes = mode match {
       case Partial | Complete => aggregateFunction.references
-      case PartialMerge | Final => AttributeSet(aggregateFunction.aggBufferAttributes)
+      case PartialMerge | Final => AttributeSet(aggregateFunction.inputAggBufferAttributes)
     }
     aggAttributes ++ filterAttributes
   }
```
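The mode dispatch above can be modeled with a small self-contained sketch; every type and attribute name below is a made-up stand-in, not the real Catalyst API:

```scala
// A simplified model of the fixed `references` logic: which attribute set an
// aggregate expression refers to, depending on its mode.
object AggReferencesSketch {
  sealed trait AggregateMode
  case object Partial extends AggregateMode
  case object PartialMerge extends AggregateMode
  case object Final extends AggregateMode
  case object Complete extends AggregateMode

  // Stand-in for an aggregate function such as `sum(x)`; attribute names are
  // invented for illustration.
  case class AggFunction(
      inputRefs: Set[String],                // columns the function reads, e.g. "x"
      aggBufferAttributes: Set[String],      // this operator's own buffer slots
      inputAggBufferAttributes: Set[String]) // buffer slots produced by the child

  def references(mode: AggregateMode, f: AggFunction): Set[String] = mode match {
    case Partial | Complete   => f.inputRefs
    case PartialMerge | Final => f.inputAggBufferAttributes // was aggBufferAttributes
  }

  def main(args: Array[String]): Unit = {
    val sum = AggFunction(Set("x"), Set("sum#1"), Set("sum#2"))
    assert(references(Partial, sum) == Set("x"))
    // A Final aggregate now refers to the child's buffer column, not its own.
    assert(references(Final, sum) == Set("sum#2"))
  }
}
```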

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/BaseAggregateExec.scala

Lines changed: 24 additions & 7 deletions
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, NamedExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Final, PartialMerge}
 import org.apache.spark.sql.execution.{ExplainUtils, UnaryExecNode}
 
@@ -53,15 +53,32 @@ trait BaseAggregateExec extends UnaryExecNode {
       // can't bind the `mergeExpressions` with the output of the partial aggregate, as they use
       // the `inputAggBufferAttributes` of the original `DeclarativeAggregate` before copy. Instead,
       // we shall use `inputAggBufferAttributes` after copy to match the new `mergeExpressions`.
-      val aggAttrs = aggregateExpressions
-        // there're exactly four cases needs `inputAggBufferAttributes` from child according to the
-        // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
-        // Partial -> PartialMerge, PartialMerge -> PartialMerge.
-        .filter(a => a.mode == Final || a.mode == PartialMerge).map(_.aggregateFunction)
-        .flatMap(_.inputAggBufferAttributes)
+      val aggAttrs = inputAggBufferAttributes
       child.output.dropRight(aggAttrs.length) ++ aggAttrs
     } else {
       child.output
     }
   }
+
+  private val inputAggBufferAttributes: Seq[Attribute] = {
+    aggregateExpressions
+      // there're exactly four cases needs `inputAggBufferAttributes` from child according to the
+      // agg planning in `AggUtils`: Partial -> Final, PartialMerge -> Final,
+      // Partial -> PartialMerge, PartialMerge -> PartialMerge.
+      .filter(a => a.mode == Final || a.mode == PartialMerge)
+      .flatMap(_.aggregateFunction.inputAggBufferAttributes)
+  }
+
+  protected val aggregateBufferAttributes: Seq[AttributeReference] = {
+    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+  }
+
+  override def producedAttributes: AttributeSet =
+    AttributeSet(aggregateAttributes) ++
+      AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+      AttributeSet(aggregateBufferAttributes) ++
+      // it's not empty when the inputAggBufferAttributes is not equal to the aggregate buffer
+      // attributes of the child Aggregate, when the child Aggregate contains the subquery in
+      // AggregateFunction. See SPARK-31620 for more details.
+      AttributeSet(inputAggBufferAttributes.filterNot(child.output.contains))
 }
```
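For the `producedAttributes` override, a toy model of the invariant that `QueryTest.assertEmptyMissingInput` enforces may help: whatever an operator references must either come from its child's output or be produced by the operator itself. All names below are hypothetical; this mirrors the idea, not Spark's classes.

```scala
// A toy model of the invariant behind `QueryTest.assertEmptyMissingInput`.
object MissingInputSketch {
  case class Operator(
      references: Set[String],         // attributes the operator reads
      childOutput: Set[String],        // attributes supplied by the child
      producedAttributes: Set[String]) // attributes the operator itself produces

  // Analogue of Catalyst's `missingInput`: references that nobody provides.
  def missingInput(op: Operator): Set[String] =
    op.references -- op.childOutput -- op.producedAttributes

  def main(args: Array[String]): Unit = {
    // The SPARK-31620 shape: after planning, the Final aggregate's own copy of
    // the buffer attribute ("sum#3") differs from the child's ("sum#2"), so the
    // override above must claim it via producedAttributes.
    val finalAgg = Operator(
      references = Set("sum#3"),
      childOutput = Set("sum#2"),
      producedAttributes = Set("sum(x)#4", "sum#1", "sum#3"))
    assert(missingInput(finalAgg).isEmpty)
  }
}
```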

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala

Lines changed: 0 additions & 9 deletions
```diff
@@ -57,10 +57,6 @@ case class HashAggregateExec(
   with BlockingOperatorWithCodegen
   with AliasAwareOutputPartitioning {
 
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
   override lazy val allAttributes: AttributeSeq =
@@ -79,11 +75,6 @@ case class HashAggregateExec(
 
   override protected def outputExpressions: Seq[NamedExpression] = resultExpressions
 
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
       case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
```

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala

Lines changed: 0 additions & 9 deletions
```diff
@@ -69,10 +69,6 @@ case class ObjectHashAggregateExec(
     child: SparkPlan)
   extends BaseAggregateExec with AliasAwareOutputPartitioning {
 
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
@@ -84,11 +80,6 @@ case class ObjectHashAggregateExec(
 
   override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
 
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
   override def requiredChildDistribution: List[Distribution] = {
     requiredChildDistributionExpressions match {
       case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
```

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala

Lines changed: 0 additions & 9 deletions
```diff
@@ -42,15 +42,6 @@ case class SortAggregateExec(
   with AliasAwareOutputPartitioning
   with AliasAwareOutputOrdering {
 
-  private[this] val aggregateBufferAttributes = {
-    aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
-  }
-
-  override def producedAttributes: AttributeSet =
-    AttributeSet(aggregateAttributes) ++
-    AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
-    AttributeSet(aggregateBufferAttributes)
-
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
```
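The deletions in `HashAggregateExec`, `ObjectHashAggregateExec`, and `SortAggregateExec` are one and the same refactoring: `aggregateBufferAttributes` and `producedAttributes` now live once in `BaseAggregateExec` instead of being duplicated in each operator. A generic sketch of the pattern, with illustrative names only:

```scala
// Sketch of the refactoring pattern applied here: hoist members that were
// copy-pasted across concrete operators into a shared base trait.
trait BaseExecSketch {
  def inputs: Seq[String]
  // Defined once here instead of in every concrete operator.
  val bufferNames: Seq[String] = inputs.map(_ + "#buf")
}

case class HashExecSketch(inputs: Seq[String]) extends BaseExecSketch
case class SortExecSketch(inputs: Seq[String]) extends BaseExecSketch

object TraitDedupDemo {
  def main(args: Array[String]): Unit = {
    // Both operators share the single definition from the trait.
    assert(HashExecSketch(Seq("x")).bufferNames == Seq("x#buf"))
    assert(SortExecSketch(Seq("y")).bufferNames == Seq("y#buf"))
  }
}
```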
