@@ -318,17 +318,31 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

   /**
    * Statically evaluate an expression containing zero or more placeholders, given a set
-   * of bindings for placeholder values.
+   * of bindings for placeholder values, if the expression is evaluable. If it is not,
+   * bind the statically evaluated results into the expression and return it unevaluated.
    */
-  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+  private def bindingExpr(
+      expr: Expression,
+      bindings: Map[ExprId, Option[Expression]]): Option[Expression] = {
     val rewrittenExpr = expr transform {
       case r: AttributeReference =>
         bindings(r.exprId) match {
-          case Some(v) => Literal.create(v, r.dataType)
+          case Some(v) => v
           case None => Literal.default(NullType)
         }
     }
-    Option(rewrittenExpr.eval())
+    if (!rewrittenExpr.foldable) {
> Contributor: Shall we apply the check in more places? evalAggOnZeroTups also calls eval() directly.
>
> Member Author: Yes. This is not possible for PythonUDF, but it is possible for other non-foldable expressions.
>
> Member Author: So it is not covered by the added test. Let me add a test for it...

+      // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
+      // They need to be evaluated at query runtime.
+      Some(rewrittenExpr)
+    } else {
+      val exprVal = rewrittenExpr.eval()
+      if (exprVal == null) {
+        None
> Contributor: Do you know why we need to return None here instead of a null literal?
>
> Member Author: I think None is used to make checking the bindings easier.
>
> Alternatively, to use a null literal, Option[Expression] could be changed to Expression in methods like evalSubqueryOnZeroTups and evalPlan, and we would then check bindings against a literal instead of None. The upside is that we could write Literal.create(rewrittenExpr.eval(), expr.dataType) instead of checking for null. It looks like just a matter of choice.

+      } else {
+        Some(Literal.create(exprVal, expr.dataType))
+      }
+    }
+  }
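
As an aside on the foldable gate above: only a tree built entirely from foldable nodes can be evaluated at optimization time, which is exactly what rewrittenExpr.foldable checks. A minimal sketch (illustration only, not part of the patch; Add and Literal are Catalyst expression classes):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}

// A literal-only tree is foldable, so eval() needs no input row and
// can safely run at optimization time.
val folded = Add(Literal(1), Literal(2))
assert(folded.foldable)
assert(folded.eval() == 3)

// A PythonUDF leaf reports foldable == false, so bindingExpr returns
// the rewritten expression unevaluated and defers it to query runtime.
```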

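And a compile-only sketch of the two API shapes weighed in the None-versus-null-literal thread above (hypothetical trait names, illustration only):

```scala
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Shape used by the patch: "no value on zero tuples" is Option.empty,
// so callers check bindings with isEmpty/isDefined.
trait OptionShape {
  def evalSubqueryOnZeroTups(plan: LogicalPlan): Option[Expression]
}

// Alternative from the discussion: always return an Expression and
// encode absence as a null literal, letting callers write
// Literal.create(rewrittenExpr.eval(), expr.dataType) instead of
// checking for null.
trait NullLiteralShape {
  def evalSubqueryOnZeroTups(plan: LogicalPlan): Expression
}
```
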
   /**
@@ -354,27 +368,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
    * CheckAnalysis become less restrictive, this method will need to change.
    */
-  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
     // Inputs to this method will start with a chain of zero or more SubqueryAlias
     // and Project operators, followed by an optional Filter, followed by an
     // Aggregate. Traverse the operators recursively.
-    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
+    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Expression]] = lp match {
       case SubqueryAlias(_, child) => evalPlan(child)
       case Filter(condition, child) =>
-        val bindings = evalPlan(child)
-        if (bindings.isEmpty) bindings
-        else {
-          val exprResult = evalExpr(condition, bindings).getOrElse(false)
-            .asInstanceOf[Boolean]
-          if (exprResult) bindings else Map.empty
-        }
+        evalPlan(child)
> Contributor: Shouldn't we evaluate the filter condition?


       case Project(projectList, child) =>
         val bindings = evalPlan(child)
         if (bindings.isEmpty) {
           bindings
         } else {
-          projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+          projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
         }

       case Aggregate(_, aggExprs, _) =>
@@ -384,7 +392,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         aggExprs.map {
           case ref: AttributeReference => (ref.exprId, None)
           case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
-          case ne => (ne.exprId, evalAggOnZeroTups(ne))
+          case ne =>
+            val aggEval = evalAggOnZeroTups(ne)
+            if (aggEval.isEmpty) {
+              (ne.exprId, None)
+            } else {
+              (ne.exprId, Some(Literal.create(aggEval.get, ne.dataType)))
+            }
         }.toMap

       case _ =>
@@ -473,7 +487,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
           currentChild.output :+
             Alias(
               If(IsNull(alwaysTrueRef),
-                Literal.create(resultWithZeroTups.get, origOutput.dataType),
+                resultWithZeroTups.get,
                 aggValRef), origOutput.name)(exprId = origOutput.exprId),
           Join(currentChild,
             Project(query.output :+ alwaysTrueExpr, query),
@@ -494,11 +508,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         case op => sys.error(s"Unexpected operator $op in corelated subquery")
       }

-      // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+      // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
       //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
       //      ELSE (aggregate value) END AS (original column name)
       val caseExpr = Alias(CaseWhen(Seq(
-        (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
+        (IsNull(alwaysTrueRef), resultWithZeroTups.get),
         (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
         aggValRef),
         origOutput.name)(exprId = origOutput.exprId)
sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala (128 additions, 0 deletions)
@@ -1384,4 +1384,132 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
       "SubqueryExec name should start with scalar-subquery#")
   }

test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

// Case 1: Canonical example of the COUNT bug
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) < l.a"),
Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
// Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
// a rewrite that is vulnerable to the COUNT bug
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) = 0"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
// Case 3: COUNT bug without a COUNT aggregate
checkAnswer(
sql("select l.a from l where (select udf(sum(r.d)) is null from r where l.a = r.c)"),
Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil)
}
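
For context, the COUNT bug is the pitfall where rewriting a correlated scalar COUNT subquery into an outer join plus aggregate loses the count-of-zero for unmatched outer rows. A standalone sketch of the symptom (illustration only; hypothetical fixture data and a local SparkSession rather than the suite's l/r views):

```scala
import org.apache.spark.sql.SparkSession

object CountBugSketch extends App {
  val spark = SparkSession.builder().master("local[2]").appName("count-bug").getOrCreate()
  import spark.implicits._

  Seq(1, 2, 3).toDF("a").createOrReplaceTempView("l")
  Seq(2, 2, 3).toDF("c").createOrReplaceTempView("r")

  // a = 1 has no match in r, so the scalar subquery must evaluate to 0
  // (not NULL); a COUNT-bug-afflicted rewrite would drop that row.
  spark.sql(
    "select l.a from l where (select count(*) from r where l.a = r.c) < l.a").show()

  spark.stop()
}
```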

test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select a, (select udf(count(*)) from r where l.a = r.c) as cnt from l"),
Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0)
:: Row(null, 0) :: Row(6, 1) :: Nil)
}

test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a as grp_a from l group by l.a " +
"having (select udf(count(*)) from r where grp_a = r.c) = 0 " +
"order by grp_a"),
Row(null) :: Row(1) :: Nil)
}

test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a as aval, sum((select udf(count(*)) from r where l.a = r.c)) as cnt " +
"from l group by l.a order by aval"),
Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
}

test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

// Case 1: Potential COUNT bug case that was working correctly prior to the fix
checkAnswer(
sql("select l.a from l where (select udf(sum(r.d)) from r where l.a = r.c) is null"),
Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
// Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
checkAnswer(
sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) > 0"),
Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
// Case 3: COUNT inside aggregate expression but no COUNT bug.
checkAnswer(
sql("select l.a from l where (select udf(count(*)) + udf(sum(r.d)) " +
"from r where l.a = r.c) = 0"),
Nil)
}

test("SPARK-28441: COUNT bug in subquery in subquery in subquery with PythonUDF") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe just say in nested subquery

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. it was copied from old test.

    import IntegratedUDFTestUtils._

    val pythonTestUDF = TestPythonUDF(name = "udf")
    registerTestUDF(pythonTestUDF, spark)

    checkAnswer(
      sql("""select l.a from l
            |where (
            |  select cntPlusOne + 1 as cntPlusTwo from (
            |    select cnt + 1 as cntPlusOne from (
            |      select udf(sum(r.c)) s, udf(count(*)) cnt from r where l.a = r.c
            |      having cnt = 0
            |    )
            |  )
            |) = 2""".stripMargin),
      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
  }

test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

checkAnswer(
sql("select l.a from l where " +
"(select case when udf(count(*)) = 1 then null else udf(count(*)) end as cnt " +
"from r where l.a = r.c) = 0"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use multi-line string to write long SQL? Let's also upper case the keywords.

Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
}

test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") {
import IntegratedUDFTestUtils._

val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)
> Member: BTW, we should add assume(shouldTestPythonUDFs). Maybe it's not a big deal in general, but it can matter in other vendors' test bases: if somebody launches the tests in a minimal Docker image, they might suddenly start failing. This kind of skipping isn't completely new in our test base; see TestUtils.testCommandAvailable for instance.

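A sketch of the suggested guard (shouldTestPythonUDFs is the existing flag in IntegratedUDFTestUtils; the clue message is illustrative):

```scala
// Skip, rather than fail, when the environment cannot run Python UDFs,
// in the spirit of TestUtils.testCommandAvailable.
assume(shouldTestPythonUDFs, "Python environment unavailable; skipping PythonUDF test")
```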

    checkAnswer(
      sql(
        """
          |select l.b, (select (r.c + udf(count(*))) is null
          |from r
          |where l.a = r.c group by r.c) from l
        """.stripMargin),
      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
  }
}

> Contributor (on the SQL string above): Let's format the SQL in a more readable way. For this particular example:
>
>   select
>     l.b,
>     (
>       select (r.c + udf(count(*))) is null
>       from r
>       where l.a = r.c
>       group by r.c
>     )
>   from l