@@ -533,18 +533,13 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}

/*
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
*/
private def constructAggregate(
private def getFinalGroupByExpressions(
selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression],
aggregationExprs: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
groupByExprs: Seq[Expression]): Seq[Expression] = {
// In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
// can be null. In such case, we derive the groupByExprs from the user supplied values for
// grouping sets.
val finalGroupByExpressions = if (groupByExprs == Nil) {
if (groupByExprs == Nil) {
selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
// Only unique expressions are included in the group by expressions and is determined
// based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -558,6 +553,17 @@ class Analyzer(
} else {
groupByExprs
}
}

/*
* Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
*/
private def constructAggregate(
selectedGroupByExprs: Seq[Seq[Expression]],
groupByExprs: Seq[Expression],
aggregationExprs: Seq[NamedExpression],
child: LogicalPlan): LogicalPlan = {
val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)

if (finalGroupByExpressions.size > GroupingID.dataType.defaultSize * 8) {
throw new AnalysisException(
@@ -595,8 +601,70 @@
}
}

// This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
val aggForResolving = h.child match {
// For CUBE/ROLLUP expressions, to avoid resolving them repeatedly, we strip them from
// groupingExpressions while resolving the condition.
case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
a.copy(groupingExpressions = groupByExprs)
case g: GroupingSets =>
Aggregate(
getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
g.aggregations, g.child)
}
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo =
ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)

// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
val newChild = h.child match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
}

// The exprId of extraAggExprs will be changed in the constructed aggregate, while
// aggregateExpressions keeps the input order, so here we build an exprMap to resolve the
// condition again.
val exprMap = extraAggExprs.zip(
newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
extraAggExprs.length)).toMap
val newCond = resolvedHavingCond.transform {
case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
}
Project(newChild.output.dropRight(extraAggExprs.length),
Filter(newCond, newChild))
} else {
h
}
}

// This requires transformDown to resolve the having condition when generating the aggregate
// node for CUBE/ROLLUP/GROUPING SETS. It also replaces grouping()/grouping_id() in resolved
// Filter/Sort.
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(
_, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(h)
case h @ UnresolvedHaving(_, g: GroupingSets)
if g.childrenResolved && g.expressions.forall(_.resolved) =>
tryResolveHavingCondition(h)

case a if !a.childrenResolved => a // be sure all of the children are resolved.

// Ensure group by expressions and aggregate expressions have been resolved.
@@ -1404,7 +1472,7 @@
}

// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
case h: AggregateWithHaving => h
case h: UnresolvedHaving => h

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
@@ -2049,7 +2117,7 @@
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolving the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)

case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -2125,13 +2193,13 @@
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
def resolveFilterCondInAggregate(
filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
Alias(filter.condition, "havingCondition")() :: Nil,
Alias(filterCond, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
@@ -2163,22 +2231,33 @@
alias.toAttribute
}
}

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
Some(aggregateExpressions, transformedAggregateFilter)
} else {
filter
None
}
} else {
filter
None
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => filter
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return None and the caller side will return the original plan.
case ae: AnalysisException => None
}
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)

// Push the aggregate expressions into the aggregate (if any).
if (resolvedInfo.nonEmpty) {
val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
Project(agg.output,
Filter(resolvedHavingCond,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
}
}
@@ -2607,12 +2686,12 @@
case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE clause")

case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
@@ -540,11 +540,12 @@ case class UnresolvedOrdinal(ordinal: Int)
}

/**
* Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
* Represents an unresolved HAVING clause; its child can be an Aggregate, GroupingSets, Rollup
* or Cube. It is turned by the analyzer into a Filter.
*/
case class AggregateWithHaving(
case class UnresolvedHaving(
havingCondition: Expression,
child: Aggregate)
child: LogicalPlan)
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
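To illustrate the new child shapes (the queries are taken from the tests added to having.sql later in this diff), roughly speaking the first query parses to an UnresolvedHaving over a GroupingSets node, while the other two parse to an UnresolvedHaving over an Aggregate whose grouping expression is a Cube or Rollup:

SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
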
@@ -368,7 +368,7 @@ package object dsl {
groupingExprs: Expression*)(
aggregateExprs: Expression*)(
havingCondition: Expression): LogicalPlan = {
AggregateWithHaving(havingCondition,
UnresolvedHaving(havingCondition,
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
}

@@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
plan match {
case aggregate: Aggregate =>
AggregateWithHaving(predicate, aggregate)
case _ =>
Filter(predicate, plan)
}
UnresolvedHaving(predicate, plan)
Contributor:

does global aggregate still work? e.g. UnresolvedHaving(Project(agg_func ...))

Member (Author):

Yes, it still works; the UnresolvedHaving will be changed to a Filter in the ResolveAggregateFunctions rule.
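
For reference, the existing test in having.sql (shown later in this diff) already covers this shape: an aggregate function under HAVING with no GROUP BY, which, per the reviewer's example above, parses to an UnresolvedHaving over a non-Aggregate (Project) child:

-- existing global-aggregate HAVING test from having.sql
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);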

}

/**
7 changes: 6 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;

-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;

-- SPARK-31663: Grouping sets with having clause returns the wrong result
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
28 changes: 27 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 6
-- Number of queries: 9


-- !query
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2,
struct<b:bigint,fake:date>
-- !query output
2 2020-01-01


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2
2


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2
2


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2