-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28441][SQL][Python] Fix error when non-foldable expression is used in correlated scalar subquery #25204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
725304c
7972d7c
33441a3
110a39e
a7803f5
0158d85
2dd29c1
9aea844
1f6b717
d7d023d
fd29677
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -318,17 +318,31 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
|
|
||
| /** | ||
| * Statically evaluate an expression containing zero or more placeholders, given a set | ||
| * of bindings for placeholder values. | ||
| * of bindings for placeholder values, if the expression is evaluable. If it is not, | ||
| * bind statically evaluated expression results to an expression. | ||
| */ | ||
| private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = { | ||
| private def bindingExpr( | ||
| expr: Expression, | ||
| bindings: Map[ExprId, Option[Expression]]): Option[Expression] = { | ||
| val rewrittenExpr = expr transform { | ||
| case r: AttributeReference => | ||
| bindings(r.exprId) match { | ||
| case Some(v) => Literal.create(v, r.dataType) | ||
| case Some(v) => v | ||
| case None => Literal.default(NullType) | ||
| } | ||
| } | ||
| Option(rewrittenExpr.eval()) | ||
| if (!rewrittenExpr.foldable) { | ||
| // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated. | ||
| // Needs to evaluate them on query runtime. | ||
| Some(rewrittenExpr) | ||
| } else { | ||
| val exprVal = rewrittenExpr.eval() | ||
| if (exprVal == null) { | ||
| None | ||
|
||
| } else { | ||
| Some(Literal.create(exprVal, expr.dataType)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -354,27 +368,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in | ||
| * CheckAnalysis become less restrictive, this method will need to change. | ||
| */ | ||
| private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = { | ||
| private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = { | ||
| // Inputs to this method will start with a chain of zero or more SubqueryAlias | ||
| // and Project operators, followed by an optional Filter, followed by an | ||
| // Aggregate. Traverse the operators recursively. | ||
| def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { | ||
| def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Expression]] = lp match { | ||
| case SubqueryAlias(_, child) => evalPlan(child) | ||
| case Filter(condition, child) => | ||
| val bindings = evalPlan(child) | ||
| if (bindings.isEmpty) bindings | ||
| else { | ||
| val exprResult = evalExpr(condition, bindings).getOrElse(false) | ||
| .asInstanceOf[Boolean] | ||
| if (exprResult) bindings else Map.empty | ||
| } | ||
| evalPlan(child) | ||
|
||
|
|
||
| case Project(projectList, child) => | ||
| val bindings = evalPlan(child) | ||
| if (bindings.isEmpty) { | ||
| bindings | ||
| } else { | ||
| projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap | ||
| projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap | ||
| } | ||
|
|
||
| case Aggregate(_, aggExprs, _) => | ||
|
|
@@ -384,7 +392,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| aggExprs.map { | ||
| case ref: AttributeReference => (ref.exprId, None) | ||
| case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None) | ||
| case ne => (ne.exprId, evalAggOnZeroTups(ne)) | ||
| case ne => | ||
| val aggEval = evalAggOnZeroTups(ne) | ||
| if (aggEval.isEmpty) { | ||
| (ne.exprId, None) | ||
| } else { | ||
| (ne.exprId, Some(Literal.create(evalAggOnZeroTups(ne).get, ne.dataType))) | ||
| } | ||
| }.toMap | ||
|
|
||
| case _ => | ||
|
|
@@ -473,7 +487,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| currentChild.output :+ | ||
| Alias( | ||
| If(IsNull(alwaysTrueRef), | ||
| Literal.create(resultWithZeroTups.get, origOutput.dataType), | ||
| resultWithZeroTups.get, | ||
| aggValRef), origOutput.name)(exprId = origOutput.exprId), | ||
| Join(currentChild, | ||
| Project(query.output :+ alwaysTrueExpr, query), | ||
|
|
@@ -494,11 +508,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { | |
| case op => sys.error(s"Unexpected operator $op in corelated subquery") | ||
| } | ||
|
|
||
| // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups | ||
| // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups | ||
| // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>) | ||
| // ELSE (aggregate value) END AS (original column name) | ||
| val caseExpr = Alias(CaseWhen(Seq( | ||
| (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)), | ||
| (IsNull(alwaysTrueRef), resultWithZeroTups.get), | ||
| (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), | ||
| aggValRef), | ||
| origOutput.name)(exprId = origOutput.exprId) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1384,4 +1384,132 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | |
| assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")), | ||
| "SubqueryExec name should start with scalar-subquery#") | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| // Case 1: Canonical example of the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) < l.a"), | ||
| Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil) | ||
| // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses | ||
| // a rewrite that is vulnerable to the COUNT bug | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) = 0"), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| // Case 3: COUNT bug without a COUNT aggregate | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(sum(r.d)) is null from r where l.a = r.c)"), | ||
| Row(1) :: Row(1) ::Row(null) :: Row(null) :: Row(6) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select a, (select udf(count(*)) from r where l.a = r.c) as cnt from l"), | ||
| Row(1, 0) :: Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) | ||
| :: Row(null, 0) :: Row(6, 1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a as grp_a from l group by l.a " + | ||
| "having (select udf(count(*)) from r where grp_a = r.c) = 0 " + | ||
| "order by grp_a"), | ||
| Row(null) :: Row(1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a as aval, sum((select udf(count(*)) from r where l.a = r.c)) as cnt " + | ||
| "from l group by l.a order by aval"), | ||
| Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug negative examples with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| // Case 1: Potential COUNT bug case that was working correctly prior to the fix | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(sum(r.d)) from r where l.a = r.c) is null"), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil) | ||
| // Case 2: COUNT aggregate but no COUNT bug due to > 0 test. | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) from r where l.a = r.c) > 0"), | ||
| Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil) | ||
| // Case 3: COUNT inside aggregate expression but no COUNT bug. | ||
| checkAnswer( | ||
| sql("select l.a from l where (select udf(count(*)) + udf(sum(r.d)) " + | ||
| "from r where l.a = r.c) = 0"), | ||
| Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug in subquery in subquery in subquery with PythonUDF") { | ||
|
||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("""select l.a from l | ||
| |where ( | ||
| | select cntPlusOne + 1 as cntPlusTwo from ( | ||
| | select cnt + 1 as cntPlusOne from ( | ||
| | select udf(sum(r.c)) s, udf(count(*)) cnt from r where l.a = r.c | ||
| | having cnt = 0 | ||
| | ) | ||
| | ) | ||
| |) = 2""".stripMargin), | ||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
|
||
| checkAnswer( | ||
| sql("select l.a from l where " + | ||
| "(select case when udf(count(*)) = 1 then null else udf(count(*)) end as cnt " + | ||
| "from r where l.a = r.c) = 0"), | ||
|
||
| Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") { | ||
| import IntegratedUDFTestUtils._ | ||
|
|
||
| val pythonTestUDF = TestPythonUDF(name = "udf") | ||
| registerTestUDF(pythonTestUDF, spark) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW, we should add This skipping stuff isn't completely new in our test base. See |
||
|
|
||
| checkAnswer( | ||
| sql( | ||
| """ | ||
| |select l.b, (select (r.c + udf(count(*))) is null | ||
| |from r | ||
| |where l.a = r.c group by r.c) from l | ||
|
||
| """.stripMargin), | ||
| Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) :: | ||
| Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we apply the check in more places?
evalAggOnZeroTupsalso callseval()directly.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes. this is not possible for PythonUDF, but it is potential for other not foldable expression.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so it is not covered by added test. Let me add test for it...