diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3505dcccbfd8..91f852da8253 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -533,18 +533,13 @@ class Analyzer(
       }.asInstanceOf[NamedExpression]
     }
 
-    /*
-     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
-     */
-    private def constructAggregate(
+    private def getFinalGroupByExpressions(
         selectedGroupByExprs: Seq[Seq[Expression]],
-        groupByExprs: Seq[Expression],
-        aggregationExprs: Seq[NamedExpression],
-        child: LogicalPlan): LogicalPlan = {
+        groupByExprs: Seq[Expression]): Seq[Expression] = {
       // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
       // can be null. In such case, we derive the groupByExprs from the user supplied values for
       // grouping sets.
-      val finalGroupByExpressions = if (groupByExprs == Nil) {
+      if (groupByExprs == Nil) {
         selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
           // Only unique expressions are included in the group by expressions and is determined
           // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -558,6 +553,17 @@ class Analyzer(
       } else {
         groupByExprs
       }
+    }
+
+    /*
+     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
+     */
+    private def constructAggregate(
+        selectedGroupByExprs: Seq[Seq[Expression]],
+        groupByExprs: Seq[Expression],
+        aggregationExprs: Seq[NamedExpression],
+        child: LogicalPlan): LogicalPlan = {
+      val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
 
       if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
         throw new AnalysisException(
@@ -595,8 +601,70 @@ class Analyzer(
       }
     }
 
-    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+    private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
+      val aggForResolving = h.child match {
+        // For CUBE/ROLLUP expressions, to avoid resolving them repeatedly, we strip the
+        // Cube/Rollup wrapper from groupingExpressions while resolving the condition.
+        case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
+          a.copy(groupingExpressions = groupByExprs)
+        case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
+          a.copy(groupingExpressions = groupByExprs)
+        case g: GroupingSets =>
+          Aggregate(
+            getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
+            g.aggregations, g.child)
+      }
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      val resolvedInfo =
+        ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
+
+      // Push the aggregate expressions into the aggregate (if any).
+      if (resolvedInfo.nonEmpty) {
+        val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
+        val newChild = h.child match {
+          case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+            constructAggregate(
+              cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+          case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+            constructAggregate(
+              rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+          case x: GroupingSets =>
+            constructAggregate(
+              x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
+        }
+
+        // The exprIds of extraAggExprs will be changed in the constructed aggregate, while
+        // aggregateExpressions keeps the input order. So here we build an exprMap to resolve
+        // the condition again.
+        val exprMap = extraAggExprs.zip(
+          newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
+            extraAggExprs.length)).toMap
+        val newCond = resolvedHavingCond.transform {
+          case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
+        }
+        Project(newChild.output.dropRight(extraAggExprs.length),
+          Filter(newCond, newChild))
+      } else {
+        h
+      }
+    }
+
+    // This requires transformDown to resolve the having condition when generating the aggregate
+    // node for CUBE/ROLLUP/GROUPING SETS. It also replaces grouping()/grouping_id() in resolved
+    // Filter/Sort.
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+      case h @ UnresolvedHaving(
+          _, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
+          if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+      case h @ UnresolvedHaving(
+          _, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
+          if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+      case h @ UnresolvedHaving(_, g: GroupingSets)
+          if g.childrenResolved && g.expressions.forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+
       case a if !a.childrenResolved => a // be sure all of the children are resolved.
 
       // Ensure group by expressions and aggregate expressions have been resolved.
@@ -1404,7 +1472,7 @@ class Analyzer(
       }
 
       // Skip the having clause here, this will be handled in ResolveAggregateFunctions.
-      case h: AggregateWithHaving => h
+      case h: UnresolvedHaving => h
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
@@ -2049,7 +2117,7 @@ class Analyzer(
     // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
     // resolve the having condition expression, here we skip resolving it in ResolveReferences
     // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
-    case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+    case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
       resolveHaving(Filter(cond, agg), agg)
 
     case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -2125,13 +2193,13 @@ class Analyzer(
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
 
-    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
-      // Try resolving the condition of the filter as though it is in the aggregate clause
+    def resolveFilterCondInAggregate(
+        filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
       try {
         val aggregatedCondition =
           Aggregate(
             agg.groupingExpressions,
-            Alias(filter.condition, "havingCondition")() :: Nil,
+            Alias(filterCond, "havingCondition")() :: Nil,
             agg.child)
         val resolvedOperator = executeSameContext(aggregatedCondition)
         def resolvedAggregateFilter =
@@ -2163,22 +2231,33 @@ class Analyzer(
               alias.toAttribute
           }
         }
-
-        // Push the aggregate expressions into the aggregate (if any).
         if (aggregateExpressions.nonEmpty) {
-          Project(agg.output,
-            Filter(transformedAggregateFilter,
-              agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+          Some((aggregateExpressions, transformedAggregateFilter))
         } else {
-          filter
+          None
         }
       } else {
-        filter
+        None
       }
     } catch {
-      // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
-      // just return the original plan.
-      case ae: AnalysisException => filter
+      // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+      // just return None and the caller will return the original plan.
+      case ae: AnalysisException => None
+      }
+    }
+
+    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
+
+      // Push the aggregate expressions into the aggregate (if any).
+      if (resolvedInfo.nonEmpty) {
+        val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
+        Project(agg.output,
+          Filter(resolvedHavingCond,
+            agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+      } else {
+        filter
       }
     }
   }
@@ -2607,12 +2686,12 @@ class Analyzer(
       case Filter(condition, _) if hasWindowFunction(condition) =>
         failAnalysis("It is not allowed to use window functions inside WHERE clause")
 
-      case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+      case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
         failAnalysis("It is not allowed to use window functions inside HAVING clause")
 
       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
            hasWindowFunction(aggregateExprs) &&
            a.expressions.forall(_.resolved) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 806cdeb95cca..b28be042c43f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -540,11 +540,12 @@ case class UnresolvedOrdinal(ordinal: Int)
 }
 
 /**
- * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
+ * Represents an unresolved HAVING clause; its child can be an Aggregate, GroupingSets, Rollup
+ * or Cube. It is turned by the analyzer into a Filter.
  */
-case class AggregateWithHaving(
+case class UnresolvedHaving(
     havingCondition: Expression,
-    child: Aggregate)
+    child: LogicalPlan)
   extends UnaryNode {
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index f135f50493ed..fe3fea5e35b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -368,7 +368,7 @@ package object dsl {
           groupingExprs: Expression*)(
           aggregateExprs: Expression*)(
           havingCondition: Expression): LogicalPlan = {
-        AggregateWithHaving(havingCondition,
+        UnresolvedHaving(havingCondition,
           groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
       }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 97750f467adb..c0cecf8536c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
       case p: Predicate => p
       case e => Cast(e, BooleanType)
     }
-    plan match {
-      case aggregate: Aggregate =>
-        AggregateWithHaving(predicate, aggregate)
-      case _ =>
-        Filter(predicate, plan)
-    }
+    UnresolvedHaving(predicate, plan)
   }
 
   /**
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 6868b5902939..3b75be19b567 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
 SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
 
 -- SPARK-31519: Cast in having aggregate expressions returns the wrong result
-SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;
+
+-- SPARK-31663: Grouping sets with having clause returns the wrong result
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index aa8ff7372358..1b3ac7865159 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 9
 
 
 -- !query
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2,
 struct<b:bigint,fake:date>
 -- !query output
 2	2020-01-01
+
+
+-- !query
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
+-- !query schema
+struct<b:bigint>
+-- !query output
+2
+2
+
+
+-- !query
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
+-- !query schema
+struct<b:bigint>
+-- !query output
+2
+2
+
+
+-- !query
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
+-- !query schema
+struct<b:bigint>
+-- !query output
+2
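
As a quick sanity check, below is a minimal reproduction sketch of SPARK-31663 for spark-shell. The SparkSession value `spark` and the use of show() are assumptions about the local setup, not part of this patch; the queries and expected rows are taken verbatim from the having.sql/having.sql.out changes above.

// Reproduction sketch (spark-shell; assumes the shell's predefined `spark` SparkSession).
// Before this patch, AstBuilder wrapped only Aggregate children in the having node and
// fell back to a plain Filter for GROUPING SETS/CUBE/ROLLUP plans, so the HAVING
// predicate could be resolved against the wrong attributes and return wrong results.
val df = spark.sql("""
  SELECT SUM(a) AS b
  FROM VALUES (1, 10), (2, 20) AS T(a, b)
  GROUP BY GROUPING SETS ((b), (a, b))
  HAVING b > 10
""")
df.show()
// Expected after this patch, matching having.sql.out (one row per qualifying group):
// +---+
// |  b|
// +---+
// |  2|
// |  2|
// +---+

// CUBE goes through the same tryResolveHavingCondition path and also yields two rows;
// ROLLUP yields a single row of 2 because it does not produce the (b)-only grouping set.
spark.sql(
  "SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) " +
    "GROUP BY ROLLUP(a, b) HAVING b > 10").show()

After the rewrite, the resolved plan has the shape Project(Filter(Aggregate(...))): the Project drops the extra aggregate expressions that were pushed down only to evaluate the HAVING condition, so the user-visible schema is unchanged.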