Skip to content

Commit 86bd37f

Browse files
xuanyuankingcloud-fan
authored andcommitted
[SPARK-31663][SQL] Grouping sets with having clause returns the wrong result
### What changes were proposed in this pull request? - Resolve the havingcondition with expanding the GROUPING SETS/CUBE/ROLLUP expressions together in `ResolveGroupingAnalytics`: - Change the operations resolving directions to top-down. - Try resolving the condition of the filter as though it is in the aggregate clause by reusing the function in `ResolveAggregateFunctions` - Push the aggregate expressions into the aggregate which contains the expanded operations. - Use UnresolvedHaving for all having clause. ### Why are the changes needed? Correctness bug fix. See the demo and analysis in SPARK-31663. ### Does this PR introduce _any_ user-facing change? Yes, correctness bug fix for HAVING with GROUPING SETS. ### How was this patch tested? New UTs added. Closes #28501 from xuanyuanking/SPARK-31663. Authored-by: Yuanjian Li <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 7ab167a commit 86bd37f

File tree

6 files changed

+145
-39
lines changed

6 files changed

+145
-39
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 106 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -533,18 +533,13 @@ class Analyzer(
533533
}.asInstanceOf[NamedExpression]
534534
}
535535

536-
/*
537-
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
538-
*/
539-
private def constructAggregate(
536+
private def getFinalGroupByExpressions(
540537
selectedGroupByExprs: Seq[Seq[Expression]],
541-
groupByExprs: Seq[Expression],
542-
aggregationExprs: Seq[NamedExpression],
543-
child: LogicalPlan): LogicalPlan = {
538+
groupByExprs: Seq[Expression]): Seq[Expression] = {
544539
// In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
545540
// can be null. In such case, we derive the groupByExprs from the user supplied values for
546541
// grouping sets.
547-
val finalGroupByExpressions = if (groupByExprs == Nil) {
542+
if (groupByExprs == Nil) {
548543
selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
549544
// Only unique expressions are included in the group by expressions and is determined
550545
// based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -558,6 +553,17 @@ class Analyzer(
558553
} else {
559554
groupByExprs
560555
}
556+
}
557+
558+
/*
559+
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
560+
*/
561+
private def constructAggregate(
562+
selectedGroupByExprs: Seq[Seq[Expression]],
563+
groupByExprs: Seq[Expression],
564+
aggregationExprs: Seq[NamedExpression],
565+
child: LogicalPlan): LogicalPlan = {
566+
val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
561567

562568
if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
563569
throw new AnalysisException(
@@ -595,8 +601,70 @@ class Analyzer(
595601
}
596602
}
597603

598-
// This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
599-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
604+
private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
605+
val aggForResolving = h.child match {
606+
// For CUBE/ROLLUP expressions, to avoid resolving repeatedly, here we delete them from
607+
// groupingExpressions for condition resolving.
608+
case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
609+
a.copy(groupingExpressions = groupByExprs)
610+
case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
611+
a.copy(groupingExpressions = groupByExprs)
612+
case g: GroupingSets =>
613+
Aggregate(
614+
getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
615+
g.aggregations, g.child)
616+
}
617+
// Try resolving the condition of the filter as though it is in the aggregate clause
618+
val resolvedInfo =
619+
ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
620+
621+
// Push the aggregate expressions into the aggregate (if any).
622+
if (resolvedInfo.nonEmpty) {
623+
val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
624+
val newChild = h.child match {
625+
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
626+
constructAggregate(
627+
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
628+
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
629+
constructAggregate(
630+
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
631+
case x: GroupingSets =>
632+
constructAggregate(
633+
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
634+
}
635+
636+
// Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the
637+
// aggregateExpressions keeps the input order. So here we build an exprMap to resolve the
638+
// condition again.
639+
val exprMap = extraAggExprs.zip(
640+
newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
641+
extraAggExprs.length)).toMap
642+
val newCond = resolvedHavingCond.transform {
643+
case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
644+
}
645+
Project(newChild.output.dropRight(extraAggExprs.length),
646+
Filter(newCond, newChild))
647+
} else {
648+
h
649+
}
650+
}
651+
652+
// This require transformDown to resolve having condition when generating aggregate node for
653+
// CUBE/ROLLUP/GROUPING SETS. This also replace grouping()/grouping_id() in resolved
654+
// Filter/Sort.
655+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
656+
case h @ UnresolvedHaving(
657+
_, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
658+
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
659+
tryResolveHavingCondition(h)
660+
case h @ UnresolvedHaving(
661+
_, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
662+
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
663+
tryResolveHavingCondition(h)
664+
case h @ UnresolvedHaving(_, g: GroupingSets)
665+
if g.childrenResolved && g.expressions.forall(_.resolved) =>
666+
tryResolveHavingCondition(h)
667+
600668
case a if !a.childrenResolved => a // be sure all of the children are resolved.
601669

602670
// Ensure group by expressions and aggregate expressions have been resolved.
@@ -1404,7 +1472,7 @@ class Analyzer(
14041472
}
14051473

14061474
// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
1407-
case h: AggregateWithHaving => h
1475+
case h: UnresolvedHaving => h
14081476

14091477
case q: LogicalPlan =>
14101478
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
@@ -2049,7 +2117,7 @@ class Analyzer(
20492117
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
20502118
// resolve the having condition expression, here we skip resolving it in ResolveReferences
20512119
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
2052-
case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
2120+
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
20532121
resolveHaving(Filter(cond, agg), agg)
20542122

20552123
case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -2125,13 +2193,13 @@ class Analyzer(
21252193
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
21262194
}
21272195

2128-
def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
2129-
// Try resolving the condition of the filter as though it is in the aggregate clause
2196+
def resolveFilterCondInAggregate(
2197+
filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
21302198
try {
21312199
val aggregatedCondition =
21322200
Aggregate(
21332201
agg.groupingExpressions,
2134-
Alias(filter.condition, "havingCondition")() :: Nil,
2202+
Alias(filterCond, "havingCondition")() :: Nil,
21352203
agg.child)
21362204
val resolvedOperator = executeSameContext(aggregatedCondition)
21372205
def resolvedAggregateFilter =
@@ -2163,22 +2231,33 @@ class Analyzer(
21632231
alias.toAttribute
21642232
}
21652233
}
2166-
2167-
// Push the aggregate expressions into the aggregate (if any).
21682234
if (aggregateExpressions.nonEmpty) {
2169-
Project(agg.output,
2170-
Filter(transformedAggregateFilter,
2171-
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2235+
Some(aggregateExpressions, transformedAggregateFilter)
21722236
} else {
2173-
filter
2237+
None
21742238
}
21752239
} else {
2176-
filter
2240+
None
21772241
}
21782242
} catch {
2179-
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
2180-
// just return the original plan.
2181-
case ae: AnalysisException => filter
2243+
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
2244+
// just return None and the caller side will return the original plan.
2245+
case ae: AnalysisException => None
2246+
}
2247+
}
2248+
2249+
def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
2250+
// Try resolving the condition of the filter as though it is in the aggregate clause
2251+
val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
2252+
2253+
// Push the aggregate expressions into the aggregate (if any).
2254+
if (resolvedInfo.nonEmpty) {
2255+
val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
2256+
Project(agg.output,
2257+
Filter(resolvedHavingCond,
2258+
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2259+
} else {
2260+
filter
21822261
}
21832262
}
21842263
}
@@ -2607,12 +2686,12 @@ class Analyzer(
26072686
case Filter(condition, _) if hasWindowFunction(condition) =>
26082687
failAnalysis("It is not allowed to use window functions inside WHERE clause")
26092688

2610-
case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
2689+
case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
26112690
failAnalysis("It is not allowed to use window functions inside HAVING clause")
26122691

26132692
// Aggregate with Having clause. This rule works with an unresolved Aggregate because
26142693
// a resolved Aggregate will not have Window Functions.
2615-
case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
2694+
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
26162695
if child.resolved &&
26172696
hasWindowFunction(aggregateExprs) &&
26182697
a.expressions.forall(_.resolved) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,11 +540,12 @@ case class UnresolvedOrdinal(ordinal: Int)
540540
}
541541

542542
/**
543-
* Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
543+
* Represents unresolved having clause, the child for it can be Aggregate, GroupingSets, Rollup
544+
* and Cube. It is turned by the analyzer into a Filter.
544545
*/
545-
case class AggregateWithHaving(
546+
case class UnresolvedHaving(
546547
havingCondition: Expression,
547-
child: Aggregate)
548+
child: LogicalPlan)
548549
extends UnaryNode {
549550
override lazy val resolved: Boolean = false
550551
override def output: Seq[Attribute] = child.output

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ package object dsl {
368368
groupingExprs: Expression*)(
369369
aggregateExprs: Expression*)(
370370
havingCondition: Expression): LogicalPlan = {
371-
AggregateWithHaving(havingCondition,
371+
UnresolvedHaving(havingCondition,
372372
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
373373
}
374374

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
629629
case p: Predicate => p
630630
case e => Cast(e, BooleanType)
631631
}
632-
plan match {
633-
case aggregate: Aggregate =>
634-
AggregateWithHaving(predicate, aggregate)
635-
case _ =>
636-
Filter(predicate, plan)
637-
}
632+
UnresolvedHaving(predicate, plan)
638633
}
639634

640635
/**

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
1818
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
1919

2020
-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
21-
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
21+
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;
22+
23+
-- SPARK-31663: Grouping sets with having clause returns the wrong result
24+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
25+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
26+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;

sql/core/src/test/resources/sql-tests/results/having.sql.out

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 6
2+
-- Number of queries: 9
33

44

55
-- !query
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2,
5555
struct<b:bigint,fake:date>
5656
-- !query output
5757
2 2020-01-01
58+
59+
60+
-- !query
61+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
62+
-- !query schema
63+
struct<b:bigint>
64+
-- !query output
65+
2
66+
2
67+
68+
69+
-- !query
70+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
71+
-- !query schema
72+
struct<b:bigint>
73+
-- !query output
74+
2
75+
2
76+
77+
78+
-- !query
79+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
80+
-- !query schema
81+
struct<b:bigint>
82+
-- !query output
83+
2

0 commit comments

Comments
 (0)