sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -595,8 +595,58 @@ class Analyzer(
}
}

private def tryResolveHavingCondition(
a: UnresolvedHaving, havingCond: Expression, agg: LogicalPlan): LogicalPlan = {
val aggForResolving = agg match {
case a: Aggregate =>
// For CUBE/ROLLUP expressions, clear groupingExpressions before resolving the
// condition, so they are not resolved repeatedly.
a.copy(groupingExpressions = Seq.empty)
case g: GroupingSets =>
Aggregate(g.groupByExprs, g.aggregations, g.child)
}
// Try resolving the condition of the filter as though it is in the aggregate clause
val (extraAggExprs, transformedAggregateFilter) =
ResolveAggregateFunctions.resolveFilterCondInAggregate(
havingCond, aggForResolving, resolveFilterNotInGroupingExprs = true)

// Push the aggregate expressions into the aggregate (if any).
if (extraAggExprs.nonEmpty) {
val newChild = agg match {
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
constructAggregate(
rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
case x: GroupingSets =>
constructAggregate(
x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
}
Project(newChild.output.filter(_.name != "havingCondition"),
Filter(transformedAggregateFilter.get, newChild))
} else {
a
}
}

// This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
-def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
case a @ UnresolvedHaving(
havingCondition, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && !havingCondition.isInstanceOf[SubqueryExpression]
&& (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(a, havingCondition, agg)
case a @ UnresolvedHaving(
havingCondition, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
if agg.childrenResolved && !havingCondition.isInstanceOf[SubqueryExpression]
&& (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
tryResolveHavingCondition(a, havingCondition, agg)
case a @ UnresolvedHaving(havingCondition, g: GroupingSets)
if g.childrenResolved && !havingCondition.isInstanceOf[SubqueryExpression]
&& g.expressions.forall(_.resolved) =>
tryResolveHavingCondition(a, havingCondition, g)

case a if !a.childrenResolved => a // be sure all of the children are resolved.

// Ensure group by expressions and aggregate expressions have been resolved.
@@ -1404,7 +1454,7 @@ class Analyzer(
}

// Skip the having clause here; it will be handled in ResolveAggregateFunctions.
-case h: AggregateWithHaving => h
+case h: UnresolvedHaving => h

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
@@ -2049,7 +2099,7 @@ class Analyzer(
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolving the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
-case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)

case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -2125,13 +2175,17 @@
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}

-def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
-// Try resolving the condition of the filter as though it is in the aggregate clause
+def resolveFilterCondInAggregate(
+    filterCond: Expression,
+    agg: Aggregate,
+    resolveFilterNotInGroupingExprs: Boolean = false)
+  : (Seq[NamedExpression], Option[Expression]) = {
+  val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
-Alias(filter.condition, "havingCondition")() :: Nil,
+Alias(filterCond, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
@@ -2144,13 +2198,18 @@
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val groupingExpressions = if (resolveFilterNotInGroupingExprs) {
agg.groupingExpressions :+ resolvedAggregateFilter
} else {
agg.groupingExpressions
}
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
-case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
+case e: Expression if groupingExpressions.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
@@ -2163,22 +2222,29 @@
alias.toAttribute
}
}

-// Push the aggregate expressions into the aggregate (if any).
-if (aggregateExpressions.nonEmpty) {
-Project(agg.output,
-Filter(transformedAggregateFilter,
-agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
-} else {
-filter
-}
+(aggregateExpressions, Some(transformedAggregateFilter))
} else {
-filter
+(aggregateExpressions, None)
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
-case ae: AnalysisException => filter
+case ae: AnalysisException => (aggregateExpressions, None)
}
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
val (aggregateExpressions, transformedAggregateFilter) =
resolveFilterCondInAggregate(filter.condition, agg)

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter.get,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
}
}
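Note that resolveFilterCondInAggregate now returns the extra aggregate expressions instead of building the final plan itself, so the same logic serves both resolveHaving and the new grouping-analytics path. A rough sketch of the observable behavior, assuming a local SparkSession named `spark` (the query and data are illustrative, not taken from this PR's tests):

    // MAX(b) appears only in HAVING, so it is pushed into the Aggregate as a
    // hidden named expression, the Filter compares against its attribute, and
    // the outer Project (agg.output) drops it from the final schema again.
    spark.sql(
      """SELECT a, SUM(b) AS s
        |FROM VALUES (1, 10), (1, 20), (2, 5) AS T(a, b)
        |GROUP BY a
        |HAVING MAX(b) > 10""".stripMargin).show()
    // +---+---+
    // |  a|  s|
    // +---+---+
    // |  1| 30|
    // +---+---+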
@@ -2607,12 +2673,12 @@ class Analyzer(
case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE clause")

-case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
-case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
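The net effect of the Analyzer changes can be checked end-to-end with one of the queries this PR adds to having.sql (assuming a local SparkSession named `spark`); the HAVING column `b` is now resolved against the original grouping expressions before the expanded CUBE aggregate is constructed:

    // SPARK-31663 reproduction; the expected output below matches the
    // having.sql.out result added in this PR.
    spark.sql(
      """SELECT SUM(a) AS b
        |FROM VALUES (1, 10), (2, 20) AS T(a, b)
        |GROUP BY CUBE(a, b)
        |HAVING b > 10""".stripMargin).show()
    // +---+
    // |  b|
    // +---+
    // |  2|
    // |  2|
    // +---+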
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -542,9 +542,9 @@ case class UnresolvedOrdinal(ordinal: Int)
/**
* Represents an unresolved aggregate with a HAVING clause; it is turned by the analyzer into a Filter.
*/
-case class AggregateWithHaving(
+case class UnresolvedHaving(
havingCondition: Expression,
-child: Aggregate)
+child: LogicalPlan)
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
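Widening `child` from Aggregate to LogicalPlan is what allows the parser to wrap any plan in UnresolvedHaving. A minimal construction sketch (the helper `wrapHaving` is hypothetical, for illustration only):

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedHaving}
    import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // The child can now be an Aggregate, a GroupingSets, or a Project (for a
    // global aggregate); resolution is deferred to the analyzer rules above.
    def wrapHaving(child: LogicalPlan): LogicalPlan =
      UnresolvedHaving(GreaterThan(UnresolvedAttribute("b"), Literal(10)), child)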
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -368,7 +368,7 @@ package object dsl {
groupingExprs: Expression*)(
aggregateExprs: Expression*)(
havingCondition: Expression): LogicalPlan = {
-AggregateWithHaving(havingCondition,
+UnresolvedHaving(havingCondition,
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
}

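For reference, the dsl helper in this hunk (its enclosing definition sits outside the visible lines; it is assumed here to be catalyst's test-dsl method `having`) can be exercised like the sketch below, which also assumes the dsl's `sum` aggregate shorthand:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val r = LocalRelation('a.int, 'b.int)
    // Builds UnresolvedHaving(sum(b) > 1, Aggregate([a], [a, sum(b) AS total], r))
    val plan = r.having('a)('a, sum('b).as("total"))(sum('b) > 1)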
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -629,12 +629,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
-plan match {
-case aggregate: Aggregate =>
-AggregateWithHaving(predicate, aggregate)
-case _ =>
-Filter(predicate, plan)
-}
+UnresolvedHaving(predicate, plan)
Contributor: Does global aggregate still work? e.g. UnresolvedHaving(Project(agg_func ...))

Member (Author): Yes, it still works; the UnresolvedHaving will be changed to a Filter by the ResolveAggregateFunctions rule.

}

/**
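As the review thread above confirms, global aggregates keep working: the parser now wraps the plan in UnresolvedHaving unconditionally, and ResolveAggregateFunctions later rewrites it to a Filter over an Aggregate. A quick check, assuming a local SparkSession named `spark`:

    // HAVING without GROUP BY: the UnresolvedHaving child is a Project, which
    // the analyzer treats as a global aggregate.
    spark.sql("SELECT SUM(a) FROM VALUES (1), (2) AS t(a) HAVING SUM(a) > 1").show()
    // +------+
    // |sum(a)|
    // +------+
    // |     3|
    // +------+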
7 changes: 6 additions & 1 deletion sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;

-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
-SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;

-- SPARK-31663: Grouping sets with having clause returns the wrong result
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
28 changes: 27 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 9


-- !query
@@ -55,3 +55,29 @@
struct<b:bigint,fake:date>
-- !query output
2 2020-01-01


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2
2


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2
2


-- !query
SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
-- !query schema
struct<b:bigint>
-- !query output
2
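Why CUBE returns two rows here while ROLLUP returns one: HAVING b > 10 can only be satisfied by grouping sets that keep `b`, which CUBE(a, b) has twice ((a, b) and (b)) and ROLLUP(a, b) only once ((a, b)). A self-contained sketch in plain Scala (no Spark) that replays the aggregation by hand:

    object HavingOverGroupingSets {
      val data = Seq((1, 10), (2, 20)) // (a, b)

      // Aggregate by one grouping set; keys outside the set become None (NULL).
      def groups(set: Set[String]): Seq[(Option[Int], Option[Int], Int)] =
        data.groupBy { case (a, b) =>
          (if (set("a")) Some(a) else None, if (set("b")) Some(b) else None)
        }.map { case ((a, b), rows) => (a, b, rows.map(_._1).sum) }.toSeq

      def main(args: Array[String]): Unit = {
        val cube   = Seq(Set("a", "b"), Set("a"), Set("b"), Set.empty[String])
        val rollup = Seq(Set("a", "b"), Set("a"), Set.empty[String])
        // HAVING b > 10: a NULL b never qualifies, so only sets keeping b survive.
        def having(rs: Seq[(Option[Int], Option[Int], Int)]) =
          rs.filter(_._2.exists(_ > 10))
        println(having(cube.flatMap(groups)).map(_._3))   // List(2, 2)
        println(having(rollup.flatMap(groups)).map(_._3)) // List(2)
      }
    }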
90 changes: 90 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3448,6 +3448,96 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
assert(df4.schema.head.name === "randn(1)")
checkIfSeedExistsInExplain(df2)
}

// test("aaa") {
// sql("select sum(a) as b FROM VALUES " +
// "(1, 10), (2, 20) AS T(a, b) group by GROUPING SETS ((b), (a, b)) having b > 10;").show()
// }

// test("aaa") {
// sql("select sum(a) as b FROM VALUES " +
// "(1, 10), (2, 20) AS T(a, b) group by CUBE(a, b) having b > 10;").show()
// }

// test("aaa") {
// val a =
// """
// |CREATE OR REPLACE TEMPORARY VIEW testData AS SELECT * FROM VALUES
// |(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)
// |AS testData(a, b);
// """.stripMargin
// sql(a)
// val b =
// """
// |SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0;
// """.stripMargin
// sql(b)
// }

// test("aaa") {
// val a =
// """
// |create table gstest2 (a integer, b integer, c integer, d integer,
// | e integer, f integer, g integer, h integer) using parquet;
// """.stripMargin
// sql(a)
// val b =
// """
// |insert into gstest2 values
// | (1, 1, 1, 1, 1, 1, 1, 1),
// | (1, 1, 1, 1, 1, 1, 1, 2),
// | (1, 1, 1, 1, 1, 1, 2, 2),
// | (1, 1, 1, 1, 1, 2, 2, 2),
// | (1, 1, 1, 1, 2, 2, 2, 2),
// | (1, 1, 1, 2, 2, 2, 2, 2),
// | (1, 1, 2, 2, 2, 2, 2, 2),
// | (1, 2, 2, 2, 2, 2, 2, 2),
// | (2, 2, 2, 2, 2, 2, 2, 2);
// """.stripMargin
// sql(b)
// val c =
// """
// |select a,count(*) from gstest2 group by rollup(a) having a is distinct from 1 order by a;
// """.stripMargin
// sql(c)
// }

test("aaa") {
spark
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/onek.data"))
.write
.format("parquet")
.saveAsTable("onek")
sql(
"""
|select ten, sum(distinct four) from onek a
|group by grouping sets((ten,four),(ten))
|having exists (select 1 from onek b where sum(distinct a.four) = b.four)
""".stripMargin)
spark.sql("DROP TABLE IF EXISTS onek")
}
}

case class Foo(bar: Option[String])