Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1232,13 +1232,13 @@ class Analyzer(
/**
* Resolves the attribute and extract value expressions(s) by traversing the
* input expression in top down manner. The traversal is done in top-down manner as
* we need to skip over unbound lamda function expression. The lamda expressions are
* we need to skip over unbound lambda function expression. The lambda expressions are
* resolved in a different rule [[ResolveLambdaVariables]]
*
* Example :
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
*
* In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
* In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
*
* Note : In this routine, the unresolved attributes are resolved from the input plan's
* children attributes.
Expand Down Expand Up @@ -1393,6 +1393,9 @@ class Analyzer(
notMatchedActions = newNotMatchedActions)
}

// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
case h: AggregateWithHaving => h

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
Expand Down Expand Up @@ -2033,62 +2036,14 @@ class Analyzer(
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolve the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you leave some comments here about why we need this special handling for aggregate with having?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, done in c75fe08.

resolveHaving(Filter(cond, agg), agg)

// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
grouping,
Alias(cond, "havingCondition")() :: Nil,
child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if grouping.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter,
agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
} else {
f
}
} else {
f
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => f
}
case f @ Filter(_, agg: Aggregate) if agg.resolved =>
resolveHaving(f, agg)

case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

Expand Down Expand Up @@ -2159,6 +2114,63 @@ class Analyzer(
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
Alias(filter.condition, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
} else {
filter
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => filter
}
}
}

/**
Expand Down Expand Up @@ -2583,11 +2595,14 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {

case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
failAnalysis("It is not allowed to use window functions inside WHERE clause")

case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.ParserUtils
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
Expand Down Expand Up @@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}

/**
* Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
*/
case class AggregateWithHaving(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnresolvedHaving?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We also discussed the naming here: #28294 (comment)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does GroupingSets have this bug as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think GroupingSets don't have the same bug here.
But I test the 2 below queries:

select sum(a) as b, '2020-01-01' as fake
FROM VALUES (1, 10), (2, 20) AS T(a, b)
group by GROUPING SETS ((b), (a, b))
having b > 10;

and adding the cast

select sum(a) as b, cast('2020-01-01' as date) as fake
FROM VALUES (1, 10), (2, 20) AS T(a, b)
group by GROUPING SETS ((b), (a, b))
having b > 10;

Both queries return empty results, it seems the HAVING after GROUPING SET resolve expression in select clauses first. Maybe it's another bug, I think we should follow the resolving strategy of table scheme first and then select clause?

havingCondition: Expression,
child: Aggregate)
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
}
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}

def having(
groupingExprs: Expression*)(
aggregateExprs: Expression*)(
havingCondition: Expression): LogicalPlan = {
AggregateWithHaving(havingCondition,
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
}

def window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
Filter(predicate, plan)
plan match {
case aggregate: Aggregate =>
AggregateWithHaving(predicate, aggregate)
case _ =>
Filter(predicate, plan)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we also create having here? This is for global aggregate, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also for GROUPING SET.

}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
table("db", "c").groupBy()('a, 'b).where('x < 1))
table("db", "c").having()('a, 'b)('x < 1))
assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
Expand Down Expand Up @@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select g from t group by g having a > (select b from s)",
table("t")
.groupBy('g)('g)
.where('a > ScalarSubquery(table("s").select('b))))
.having('g)('g)('a > ScalarSubquery(table("s").select('b))))
}

test("table reference") {
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/having.sql
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);

-- SPARK-20329: make sure we handle timezones correctly
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;

-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/having.sql.out
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 5
-- Number of queries: 6


-- !query
Expand Down Expand Up @@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
-- !query output
3
7


-- !query
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
-- !query schema
struct<b:bigint,fake:date>
-- !query output
2 2020-01-01
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
Expand Down Expand Up @@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
Expand All @@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -665,40 +665,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}

test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
def checkAnalysisError(df: => DataFrame): Unit = {
def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
assert(thrownException.message.contains(s"window functions inside $clause clause"))
}

checkAnalysisError(testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1))
checkAnalysisError(testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1))
checkAnalysisError(
testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE")
checkAnalysisError(
testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(avg($"b").as("avgb"))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
.where(rank().over(Window.orderBy($"a")) === 1))
.where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE")

checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql(
s"""SELECT a, MAX(b)
|FROM testData2
|GROUP BY a
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
"HAVING")
}

test("window functions in multiple selects") {
Expand Down