@@ -533,18 +533,13 @@ class Analyzer(
533533 }.asInstanceOf [NamedExpression ]
534534 }
535535
536- /*
537- * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
538- */
539- private def constructAggregate (
536+ private def getFinalGroupByExpressions (
540537 selectedGroupByExprs : Seq [Seq [Expression ]],
541- groupByExprs : Seq [Expression ],
542- aggregationExprs : Seq [NamedExpression ],
543- child : LogicalPlan ): LogicalPlan = {
538+ groupByExprs : Seq [Expression ]): Seq [Expression ] = {
544539 // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
545540 // can be null. In such case, we derive the groupByExprs from the user supplied values for
546541 // grouping sets.
547- val finalGroupByExpressions = if (groupByExprs == Nil ) {
542+ if (groupByExprs == Nil ) {
548543 selectedGroupByExprs.flatten.foldLeft(Seq .empty[Expression ]) { (result, currentExpr) =>
549544 // Only unique expressions are included in the group by expressions and is determined
550545 // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -558,6 +553,17 @@ class Analyzer(
558553 } else {
559554 groupByExprs
560555 }
556+ }
557+
558+ /*
559+ * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
560+ */
561+ private def constructAggregate (
562+ selectedGroupByExprs : Seq [Seq [Expression ]],
563+ groupByExprs : Seq [Expression ],
564+ aggregationExprs : Seq [NamedExpression ],
565+ child : LogicalPlan ): LogicalPlan = {
566+ val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
561567
562568 if (finalGroupByExpressions.size > GroupingID .dataType.defaultSize * 8 ) {
563569 throw new AnalysisException (
@@ -595,8 +601,70 @@ class Analyzer(
595601 }
596602 }
597603
598- // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
599- def apply (plan : LogicalPlan ): LogicalPlan = plan resolveOperatorsUp {
604+ private def tryResolveHavingCondition (h : UnresolvedHaving ): LogicalPlan = {
605+ val aggForResolving = h.child match {
606+ // For CUBE/ROLLUP expressions, to avoid resolving repeatedly, here we delete them from
607+ // groupingExpressions for condition resolving.
608+ case a @ Aggregate (Seq (c @ Cube (groupByExprs)), _, _) =>
609+ a.copy(groupingExpressions = groupByExprs)
610+ case a @ Aggregate (Seq (r @ Rollup (groupByExprs)), _, _) =>
611+ a.copy(groupingExpressions = groupByExprs)
612+ case g : GroupingSets =>
613+ Aggregate (
614+ getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
615+ g.aggregations, g.child)
616+ }
617+ // Try resolving the condition of the filter as though it is in the aggregate clause
618+ val resolvedInfo =
619+ ResolveAggregateFunctions .resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
620+
621+ // Push the aggregate expressions into the aggregate (if any).
622+ if (resolvedInfo.nonEmpty) {
623+ val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
624+ val newChild = h.child match {
625+ case Aggregate (Seq (c @ Cube (groupByExprs)), aggregateExpressions, child) =>
626+ constructAggregate(
627+ cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
628+ case Aggregate (Seq (r @ Rollup (groupByExprs)), aggregateExpressions, child) =>
629+ constructAggregate(
630+ rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
631+ case x : GroupingSets =>
632+ constructAggregate(
633+ x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
634+ }
635+
636+ // Since the exprId of extraAggExprs will be changed in the constructed aggregate, and the
637+ // aggregateExpressions keeps the input order. So here we build an exprMap to resolve the
638+ // condition again.
639+ val exprMap = extraAggExprs.zip(
640+ newChild.asInstanceOf [Aggregate ].aggregateExpressions.takeRight(
641+ extraAggExprs.length)).toMap
642+ val newCond = resolvedHavingCond.transform {
643+ case ne : NamedExpression if exprMap.contains(ne) => exprMap(ne)
644+ }
645+ Project (newChild.output.dropRight(extraAggExprs.length),
646+ Filter (newCond, newChild))
647+ } else {
648+ h
649+ }
650+ }
651+
652+ // This require transformDown to resolve having condition when generating aggregate node for
653+ // CUBE/ROLLUP/GROUPING SETS. This also replace grouping()/grouping_id() in resolved
654+ // Filter/Sort.
655+ def apply (plan : LogicalPlan ): LogicalPlan = plan resolveOperatorsDown {
656+ case h @ UnresolvedHaving (
657+ _, agg @ Aggregate (Seq (c @ Cube (groupByExprs)), aggregateExpressions, _))
658+ if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
659+ tryResolveHavingCondition(h)
660+ case h @ UnresolvedHaving (
661+ _, agg @ Aggregate (Seq (r @ Rollup (groupByExprs)), aggregateExpressions, _))
662+ if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
663+ tryResolveHavingCondition(h)
664+ case h @ UnresolvedHaving (_, g : GroupingSets )
665+ if g.childrenResolved && g.expressions.forall(_.resolved) =>
666+ tryResolveHavingCondition(h)
667+
600668 case a if ! a.childrenResolved => a // be sure all of the children are resolved.
601669
602670 // Ensure group by expressions and aggregate expressions have been resolved.
@@ -1404,7 +1472,7 @@ class Analyzer(
14041472 }
14051473
14061474 // Skip the having clause here, this will be handled in ResolveAggregateFunctions.
1407- case h : AggregateWithHaving => h
1475+ case h : UnresolvedHaving => h
14081476
14091477 case q : LogicalPlan =>
14101478 logTrace(s " Attempting to resolve ${q.simpleString(SQLConf .get.maxToStringFields)}" )
@@ -2049,7 +2117,7 @@ class Analyzer(
20492117 // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
20502118 // resolve the having condition expression, here we skip resolving it in ResolveReferences
20512119 // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
2052- case AggregateWithHaving (cond, agg : Aggregate ) if agg.resolved =>
2120+ case UnresolvedHaving (cond, agg : Aggregate ) if agg.resolved =>
20532121 resolveHaving(Filter (cond, agg), agg)
20542122
20552123 case f @ Filter (_, agg : Aggregate ) if agg.resolved =>
@@ -2125,13 +2193,13 @@ class Analyzer(
21252193 condition.find(_.isInstanceOf [AggregateExpression ]).isDefined
21262194 }
21272195
2128- def resolveHaving ( filter : Filter , agg : Aggregate ) : LogicalPlan = {
2129- // Try resolving the condition of the filter as though it is in the aggregate clause
2196+ def resolveFilterCondInAggregate (
2197+ filterCond : Expression , agg : Aggregate ) : Option [( Seq [ NamedExpression ], Expression )] = {
21302198 try {
21312199 val aggregatedCondition =
21322200 Aggregate (
21332201 agg.groupingExpressions,
2134- Alias (filter.condition , " havingCondition" )() :: Nil ,
2202+ Alias (filterCond , " havingCondition" )() :: Nil ,
21352203 agg.child)
21362204 val resolvedOperator = executeSameContext(aggregatedCondition)
21372205 def resolvedAggregateFilter =
@@ -2163,22 +2231,33 @@ class Analyzer(
21632231 alias.toAttribute
21642232 }
21652233 }
2166-
2167- // Push the aggregate expressions into the aggregate (if any).
21682234 if (aggregateExpressions.nonEmpty) {
2169- Project (agg.output,
2170- Filter (transformedAggregateFilter,
2171- agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2235+ Some (aggregateExpressions, transformedAggregateFilter)
21722236 } else {
2173- filter
2237+ None
21742238 }
21752239 } else {
2176- filter
2240+ None
21772241 }
21782242 } catch {
2179- // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
2180- // just return the original plan.
2181- case ae : AnalysisException => filter
2243+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
2244+ // just return None and the caller side will return the original plan.
2245+ case ae : AnalysisException => None
2246+ }
2247+ }
2248+
2249+ def resolveHaving (filter : Filter , agg : Aggregate ): LogicalPlan = {
2250+ // Try resolving the condition of the filter as though it is in the aggregate clause
2251+ val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
2252+
2253+ // Push the aggregate expressions into the aggregate (if any).
2254+ if (resolvedInfo.nonEmpty) {
2255+ val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
2256+ Project (agg.output,
2257+ Filter (resolvedHavingCond,
2258+ agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
2259+ } else {
2260+ filter
21822261 }
21832262 }
21842263 }
@@ -2607,12 +2686,12 @@ class Analyzer(
26072686 case Filter (condition, _) if hasWindowFunction(condition) =>
26082687 failAnalysis(" It is not allowed to use window functions inside WHERE clause" )
26092688
2610- case AggregateWithHaving (condition, _) if hasWindowFunction(condition) =>
2689+ case UnresolvedHaving (condition, _) if hasWindowFunction(condition) =>
26112690 failAnalysis(" It is not allowed to use window functions inside HAVING clause" )
26122691
26132692 // Aggregate with Having clause. This rule works with an unresolved Aggregate because
26142693 // a resolved Aggregate will not have Window Functions.
2615- case f @ AggregateWithHaving (condition, a @ Aggregate (groupingExprs, aggregateExprs, child))
2694+ case f @ UnresolvedHaving (condition, a @ Aggregate (groupingExprs, aggregateExprs, child))
26162695 if child.resolved &&
26172696 hasWindowFunction(aggregateExprs) &&
26182697 a.expressions.forall(_.resolved) =>
0 commit comments