Skip to content

Commit ced9568

Browse files
committed
WIP
1 parent 1d1bb79 commit ced9568

File tree

5 files changed

+109
-25
lines changed

5 files changed

+109
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,50 @@ class Analyzer(
595595
}
596596
}
597597

598+
private def tryResolveHavingCondition(
599+
a: AggregateWithHaving, havingCond: Expression, agg: LogicalPlan): LogicalPlan = {
600+
val aggForResolver = agg match {
601+
case a: Aggregate =>
602+
// For CUBE/ROLLUP expressions, since they don't have their own logical plan, we need
603+
// to delete them from groupingExpressions for condition resolving.
604+
a.copy(groupingExpressions = Seq.empty)
605+
case g: GroupingSets =>
606+
Aggregate(g.groupByExprs, g.aggregations, g.child)
607+
}
608+
// Try resolving the condition of the filter as though it is in the aggregate clause
609+
val (aggregateExpressions, transformedAggregateFilter) =
610+
ResolveAggregateFunctions.resolveFilterCondInAggregate(
611+
havingCond, aggForResolver, true)
612+
613+
// Push the aggregate expressions into the aggregate (if any).
614+
if (aggregateExpressions.nonEmpty) {
615+
val newChild = agg match {
616+
case a: Aggregate =>
617+
a.copy(aggregateExpressions = a.aggregateExpressions ++ aggregateExpressions)
618+
case g: GroupingSets =>
619+
g.copy(aggregations = g.aggregations ++ aggregateExpressions)
620+
}
621+
Project(agg.output,
622+
Filter(transformedAggregateFilter.get, newChild))
623+
} else {
624+
a
625+
}
626+
}
627+
598628
// This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
599-
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
629+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
630+
case a @ AggregateWithHaving(
631+
havingCondition, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
632+
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
633+
tryResolveHavingCondition(a, havingCondition, agg)
634+
case a @ AggregateWithHaving(
635+
havingCondition, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
636+
if (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
637+
tryResolveHavingCondition(a, havingCondition, agg)
638+
case a @ AggregateWithHaving(
639+
havingCondition, g: GroupingSets) if g.expressions.forall(_.resolved) =>
640+
tryResolveHavingCondition(a, havingCondition, g)
641+
600642
case a if !a.childrenResolved => a // be sure all of the children are resolved.
601643

602644
// Ensure group by expressions and aggregate expressions have been resolved.
@@ -2125,13 +2167,17 @@ class Analyzer(
21252167
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
21262168
}
21272169

2128-
def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
2129-
// Try resolving the condition of the filter as though it is in the aggregate clause
2170+
def resolveFilterCondInAggregate(
2171+
filterCond: Expression,
2172+
agg: Aggregate,
2173+
resolveFilterNotInAggOutput: Boolean = false)
2174+
: (Seq[NamedExpression], Option[Expression]) = {
2175+
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
21302176
try {
21312177
val aggregatedCondition =
21322178
Aggregate(
21332179
agg.groupingExpressions,
2134-
Alias(filter.condition, "havingCondition")() :: Nil,
2180+
Alias(filterCond, "havingCondition")() :: Nil,
21352181
agg.child)
21362182
val resolvedOperator = executeSameContext(aggregatedCondition)
21372183
def resolvedAggregateFilter =
@@ -2144,13 +2190,18 @@ class Analyzer(
21442190
if (resolvedOperator.resolved) {
21452191
// Try to replace all aggregate expressions in the filter by an alias.
21462192
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
2193+
val groupingExpressions = if (resolveFilterNotInAggOutput) {
2194+
agg.groupingExpressions :+ resolvedAggregateFilter
2195+
} else {
2196+
agg.groupingExpressions
2197+
}
21472198
val transformedAggregateFilter = resolvedAggregateFilter.transform {
21482199
case ae: AggregateExpression =>
21492200
val alias = Alias(ae, ae.toString)()
21502201
aggregateExpressions += alias
21512202
alias.toAttribute
21522203
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
2153-
case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
2204+
case e: Expression if groupingExpressions.exists(_.semanticEquals(e)) &&
21542205
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
21552206
!agg.output.exists(_.semanticEquals(e)) =>
21562207
e match {
@@ -2163,22 +2214,29 @@ class Analyzer(
21632214
alias.toAttribute
21642215
}
21652216
}
2166-
2167-
// Push the aggregate expressions into the aggregate (if any).
2168-
if (aggregateExpressions.nonEmpty) {
2169-
Project(agg.output,
2170-
Filter(transformedAggregateFilter,
2171-
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2172-
} else {
2173-
filter
2174-
}
2217+
(aggregateExpressions, Some(transformedAggregateFilter))
21752218
} else {
2176-
filter
2219+
(aggregateExpressions, None)
21772220
}
21782221
} catch {
21792222
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
21802223
// just return the original plan.
2181-
case ae: AnalysisException => filter
2224+
case ae: AnalysisException => (aggregateExpressions, None)
2225+
}
2226+
}
2227+
2228+
def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
2229+
// Try resolving the condition of the filter as though it is in the aggregate clause
2230+
val (aggregateExpressions, transformedAggregateFilter) =
2231+
resolveFilterCondInAggregate(filter.condition, agg)
2232+
2233+
// Push the aggregate expressions into the aggregate (if any).
2234+
if (aggregateExpressions.nonEmpty) {
2235+
Project(agg.output,
2236+
Filter(transformedAggregateFilter.get,
2237+
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2238+
} else {
2239+
filter
21822240
}
21832241
}
21842242
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ case class UnresolvedOrdinal(ordinal: Int)
544544
*/
545545
case class AggregateWithHaving(
546546
havingCondition: Expression,
547-
child: Aggregate)
547+
child: LogicalPlan)
548548
extends UnaryNode {
549549
override lazy val resolved: Boolean = false
550550
override def output: Seq[Attribute] = child.output

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
629629
case p: Predicate => p
630630
case e => Cast(e, BooleanType)
631631
}
632-
plan match {
633-
case aggregate: Aggregate =>
634-
AggregateWithHaving(predicate, aggregate)
635-
case _ =>
636-
Filter(predicate, plan)
637-
}
632+
AggregateWithHaving(predicate, plan)
638633
}
639634

640635
/**

sql/core/src/test/resources/sql-tests/inputs/having.sql

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
1818
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
1919

2020
-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
21-
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
21+
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;
22+
23+
-- SPARK-31663: Grouping sets with having clause returns the wrong result
24+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
25+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
26+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;

sql/core/src/test/resources/sql-tests/results/having.sql.out

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 6
2+
-- Number of queries: 9
33

44

55
-- !query
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2,
5555
struct<b:bigint,fake:date>
5656
-- !query output
5757
2 2020-01-01
58+
59+
60+
-- !query
61+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
62+
-- !query schema
63+
struct<b:bigint>
64+
-- !query output
65+
2
66+
2
67+
68+
69+
-- !query
70+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
71+
-- !query schema
72+
struct<b:bigint>
73+
-- !query output
74+
2
75+
2
76+
77+
78+
-- !query
79+
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
80+
-- !query schema
81+
struct<b:bigint>
82+
-- !query output
83+
2

0 commit comments

Comments
 (0)