diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e55cdfedd3234..acd8c052e96e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2698,7 +2698,7 @@ object EliminateUnions extends Rule[LogicalPlan] {
  * rule can't work for those parameters.
  */
 object CleanupAliases extends Rule[LogicalPlan] {
-  private def trimAliases(e: Expression): Expression = {
+  def trimAliases(e: Expression): Expression = {
     e.transformDown {
       case Alias(child, _) => child
       case MultiAlias(child, _) => child
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c99d2c06fac63..0cb4c71510119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -650,7 +650,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
    */
   private def removeProjectBeforeFilter(plan: LogicalPlan): LogicalPlan = plan transformUp {
     case p1 @ Project(_, f @ Filter(_, p2 @ Project(_, child)))
-      if p2.outputSet.subsetOf(child.outputSet) =>
+      if p2.outputSet.subsetOf(child.outputSet) &&
+        // We only remove attribute-only projections.
+        p2.projectList.forall(_.isInstanceOf[AttributeReference]) =>
       p1.copy(child = f.copy(child = child))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index e78ed1c3c5d94..4f7333c3875f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.CleanupAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -316,25 +317,41 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     newExpression.asInstanceOf[E]
   }
 
+  /**
+   * Checks whether the given expression is foldable. If it is, evaluates it and returns
+   * the result as a literal. If not, returns the original expression unevaluated.
+   */
+  private def tryEvalExpr(expr: Expression): Expression = {
+    // Remove any Alias over the given expression, because an Alias is not foldable.
+    if (!CleanupAliases.trimAliases(expr).foldable) {
+      // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
+      // They need to be evaluated at query runtime.
+      expr
+    } else {
+      Literal.create(expr.eval(), expr.dataType)
+    }
+  }
+
   /**
    * Statically evaluate an expression containing zero or more placeholders, given a set
-   * of bindings for placeholder values.
+   * of bindings for placeholder values, if the expression is evaluable. If it is not,
+   * bind the statically evaluated parts into the expression and return it unevaluated.
   */
-  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+  private def bindingExpr(
+      expr: Expression,
+      bindings: Map[ExprId, Expression]): Expression = {
     val rewrittenExpr = expr transform {
       case r: AttributeReference =>
-        bindings(r.exprId) match {
-          case Some(v) => Literal.create(v, r.dataType)
-          case None => Literal.default(NullType)
-        }
+        bindings.getOrElse(r.exprId, Literal.default(NullType))
     }
-    Option(rewrittenExpr.eval())
+
+    tryEvalExpr(rewrittenExpr)
   }
 
   /**
    * Statically evaluate an expression containing one or more aggregates on an empty input.
    */
-  private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
+  private def evalAggOnZeroTups(expr: Expression) : Expression = {
     // AggregateExpressions are Unevaluable, so we need to replace all aggregates
     // in the expression with the value they would return for zero input tuples.
     // Also replace attribute refs (for example, for grouping columns) with NULL.
@@ -344,7 +361,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       case _: AttributeReference => Literal.default(NullType)
     }
-    Option(rewrittenExpr.eval())
+
+    tryEvalExpr(rewrittenExpr)
   }
 
   /**
@@ -354,19 +372,33 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
    * CheckAnalysis become less restrictive, this method will need to change.
    */
-  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
     // Inputs to this method will start with a chain of zero or more SubqueryAlias
     // and Project operators, followed by an optional Filter, followed by an
     // Aggregate. Traverse the operators recursively.
-    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
+    def evalPlan(lp : LogicalPlan) : Map[ExprId, Expression] = lp match {
       case SubqueryAlias(_, child) => evalPlan(child)
       case Filter(condition, child) =>
         val bindings = evalPlan(child)
-        if (bindings.isEmpty) bindings
-        else {
-          val exprResult = evalExpr(condition, bindings).getOrElse(false)
-            .asInstanceOf[Boolean]
-          if (exprResult) bindings else Map.empty
+        if (bindings.isEmpty) {
+          bindings
+        } else {
+          val bindCondition = bindingExpr(condition, bindings)
+
+          if (!bindCondition.foldable) {
+            // We can't evaluate the condition. Evaluate it at query runtime.
+            bindings.map { case (id, expr) =>
+              val newExpr = If(bindCondition, expr, Literal.create(null, expr.dataType))
+              (id, newExpr)
+            }
+          } else {
+            // The bound condition can be evaluated.
+            bindCondition.eval() match {
+              // For a filter condition, null is the same as false.
+              case null | false => Map.empty
+              case true => bindings
+            }
+          }
         }
 
       case Project(projectList, child) =>
@@ -374,7 +406,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         if (bindings.isEmpty) {
           bindings
         } else {
-          projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+          projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
         }
 
       case Aggregate(_, aggExprs, _) =>
@@ -382,8 +414,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         // for joining with the outer query block. Fill those expressions in with
         // nulls and statically evaluate the remainder.
         aggExprs.map {
-          case ref: AttributeReference => (ref.exprId, None)
-          case alias @ Alias(_: AttributeReference, _) => (alias.exprId, None)
+          case ref: AttributeReference => (ref.exprId, Literal.create(null, ref.dataType))
+          case alias @ Alias(_: AttributeReference, _) =>
+            (alias.exprId, Literal.create(null, alias.dataType))
           case ne => (ne.exprId, evalAggOnZeroTups(ne))
         }.toMap
 
@@ -394,7 +427,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     val resultMap = evalPlan(plan)
 
     // By convention, the scalar subquery result is the leftmost field.
-    resultMap.getOrElse(plan.output.head.exprId, None)
+    resultMap.get(plan.output.head.exprId) match {
+      case Some(Literal(null, _)) | None => None
+      case o => o
+    }
   }
 
   /**
@@ -473,7 +509,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
           currentChild.output :+
             Alias(
               If(IsNull(alwaysTrueRef),
-                Literal.create(resultWithZeroTups.get, origOutput.dataType),
+                resultWithZeroTups.get,
                 aggValRef), origOutput.name)(exprId = origOutput.exprId),
           Join(currentChild,
             Project(query.output :+ alwaysTrueExpr, query),
@@ -494,11 +530,11 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
           case op => sys.error(s"Unexpected operator $op in corelated subquery")
         }
 
-        // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+        // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
         //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
         //      ELSE (aggregate value) END AS (original column name)
         val caseExpr = Alias(CaseWhen(Seq(
-          (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
+          (IsNull(alwaysTrueRef), resultWithZeroTups.get),
           (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
           aggValRef),
           origOutput.name)(exprId = origOutput.exprId)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index b2c38684071dc..4ec85b0ac6d2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -1384,4 +1384,231 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
     assert(subqueryExecs.forall(_.name.startsWith("scalar-subquery#")),
       "SubqueryExec name should start with scalar-subquery#")
   }
+
+  test("SPARK-28441: COUNT bug in WHERE clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) is null FROM r WHERE l.a = r.c)"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in SELECT clause (Project) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("SELECT a, (SELECT udf(count(*)) FROM r WHERE l.a = r.c) AS cnt FROM l"),
+      Row(1, 0) ::
+        Row(1, 0) :: Row(2, 2) :: Row(2, 2) :: Row(3, 1) :: Row(null, 0) ::
+        Row(null, 0) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in HAVING clause (Filter) with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+        |SELECT
+        |  l.a AS grp_a
+        |FROM l GROUP BY l.a
+        |HAVING
+        |  (
+        |    SELECT udf(count(*)) FROM r WHERE grp_a = r.c
+        |  ) = 0
+        |ORDER BY grp_a""".stripMargin),
+      Row(null) :: Row(1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in Aggregate with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+        |SELECT
+        |  l.a AS aval,
+        |  sum(
+        |    (
+        |      SELECT udf(count(*)) FROM r WHERE l.a = r.c
+        |    )
+        |  ) AS cnt
+        |FROM l GROUP BY l.a ORDER BY aval""".stripMargin),
+      Row(null, 0) :: Row(1, 0) :: Row(2, 4) :: Row(3, 1) :: Row(6, 1) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug negative examples with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    // Case 1: Potential COUNT bug case that was working correctly prior to the fix
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(sum(r.d)) FROM r WHERE l.a = r.c) is null"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+    // Case 2: COUNT aggregate but no COUNT bug due to > 0 test.
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) > 0"),
+      Row(2) :: Row(2) :: Row(3) :: Row(6) :: Nil)
+    // Case 3: COUNT inside aggregate expression but no COUNT bug.
+    checkAnswer(
+      sql("""
+        |SELECT
+        |  l.a
+        |FROM l
+        |WHERE
+        |  (
+        |    SELECT udf(count(*)) + udf(sum(r.d))
+        |    FROM r WHERE l.a = r.c
+        |  ) = 0""".stripMargin),
+      Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in nested subquery with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+        |SELECT l.a FROM l
+        |WHERE (
+        |  SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
+        |    SELECT cnt + 1 AS cntPlusOne FROM (
+        |      SELECT udf(sum(r.c)) s, udf(count(*)) cnt FROM r WHERE l.a = r.c
+        |        HAVING cnt = 0
+        |    )
+        |  )
+        |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with nasty predicate expr with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql("""
+        |SELECT
+        |  l.a
+        |FROM l WHERE
+        |  (
+        |    SELECT CASE WHEN udf(count(*)) = 1 THEN null ELSE udf(count(*)) END AS cnt
+        |    FROM r WHERE l.a = r.c
+        |  ) = 0""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with attribute ref in subquery input and output with PythonUDF") {
+    import IntegratedUDFTestUtils._
+
+    val pythonTestUDF = TestPythonUDF(name = "udf")
+    registerTestUDF(pythonTestUDF, spark)
+
+    checkAnswer(
+      sql(
+        """
+          |SELECT
+          |  l.b,
+          |  (
+          |    SELECT (r.c + udf(count(*))) is null
+          |    FROM r
+          |    WHERE l.a = r.c GROUP BY r.c
+          |  )
+          |FROM l
+        """.stripMargin),
+      Row(1.0, false) :: Row(1.0, false) :: Row(2.0, true) :: Row(2.0, true) ::
+        Row(3.0, false) :: Row(5.0, true) :: Row(null, false) :: Row(null, true) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with non-foldable expression") {
+    // Case 1: Canonical example of the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
+        "WHERE l.a = r.c) < l.a"),
+      Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
+    // Case 2: count(*) = 0; could be rewritten to NOT EXISTS but currently uses
+    // a rewrite that is vulnerable to the COUNT bug
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT count(*) + cast(rand() as int) FROM r " +
+        "WHERE l.a = r.c) = 0"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+    // Case 3: COUNT bug without a COUNT aggregate
+    checkAnswer(
+      sql("SELECT l.a FROM l WHERE (SELECT sum(r.d) is null FROM r " +
+        "WHERE l.a = r.c)"),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Row(6) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug in nested subquery with non-foldable expr") {
+    checkAnswer(
+      sql("""
+        |SELECT l.a FROM l
+        |WHERE (
+        |  SELECT cntPlusOne + 1 AS cntPlusTwo FROM (
+        |    SELECT cnt + 1 AS cntPlusOne FROM (
+        |      SELECT sum(r.c) s, (count(*) + cast(rand() as int)) cnt FROM r
+        |        WHERE l.a = r.c HAVING cnt = 0
+        |    )
+        |  )
+        |) = 2""".stripMargin),
+      Row(1) :: Row(1) :: Row(null) :: Row(null) :: Nil)
+  }
+
+  test("SPARK-28441: COUNT bug with non-foldable expression in Filter condition") {
+    val df = sql("""
+      |SELECT
+      |  l.a
+      |FROM l WHERE
+      |  (
+      |    SELECT cntPlusOne + 1 AS cntPlusTwo FROM
+      |    (
+      |      SELECT cnt + 1 AS cntPlusOne FROM
+      |      (
+      |        SELECT sum(r.c) s, count(*) cnt FROM r WHERE l.a = r.c HAVING cnt > 0
+      |      )
+      |    )
+      |  ) = 2""".stripMargin)
+    val df2 = sql("""
+      |SELECT
+      |  l.a
+      |FROM l WHERE
+      |  (
+      |    SELECT cntPlusOne + 1 AS cntPlusTwo
+      |    FROM
+      |    (
+      |      SELECT cnt + 1 AS cntPlusOne
+      |      FROM
+      |      (
+      |        SELECT sum(r.c) s, count(*) cnt FROM r
+      |        WHERE l.a = r.c HAVING (cnt + cast(rand() as int)) > 0
+      |      )
+      |    )
+      |  ) = 2""".stripMargin)
+    checkAnswer(df, df2)
+    checkAnswer(df, Nil)
+  }
 }
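
The core of this patch is a fold-or-defer rule: an expression derived from the subquery is constant-folded only when it is foldable, and otherwise kept in the plan so it runs at query runtime, which is what lets a PythonUDF survive the rewrite. The following is a minimal, self-contained sketch of that rule in plain Scala, with no Spark dependencies; Expr, Lit, Attr, Udf and Add are hypothetical stand-ins for Catalyst's Expression, Literal, AttributeReference and PythonUDF, not the actual classes.

// Sketch of the fold-or-defer pattern used by tryEvalExpr/bindingExpr above.
sealed trait Expr { def foldable: Boolean }
case class Lit(value: Any) extends Expr { val foldable = true }
case class Attr(name: String) extends Expr { val foldable = false }
case class Udf(name: String, child: Expr) extends Expr { val foldable = false }
case class Add(left: Expr, right: Expr) extends Expr {
  def foldable: Boolean = left.foldable && right.foldable
}

object FoldOrDefer {
  // Mirrors tryEvalExpr: evaluate to a literal only if the expression is foldable.
  def tryEval(e: Expr): Expr = if (e.foldable) Lit(eval(e)) else e

  private def eval(e: Expr): Any = e match {
    case Lit(v) => v
    case Add(l, r) => eval(l).asInstanceOf[Int] + eval(r).asInstanceOf[Int]
    case other => sys.error(s"not statically evaluable: $other")
  }

  // Mirrors bindingExpr: substitute bindings for attributes, then fold if possible.
  def bind(e: Expr, bindings: Map[String, Expr]): Expr = {
    def substitute(ex: Expr): Expr = ex match {
      case a: Attr => bindings.getOrElse(a.name, Lit(null))
      case Add(l, r) => Add(substitute(l), substitute(r))
      case u: Udf => u.copy(child = substitute(u.child))
      case lit => lit
    }
    tryEval(substitute(e))
  }
}

object FoldOrDeferDemo extends App {
  import FoldOrDefer._
  // count(*) over zero tuples binds to the literal 0, so cnt + 1 folds to Lit(1):
  println(bind(Add(Attr("cnt"), Lit(1)), Map("cnt" -> Lit(0))))
  // A UDF is never foldable, so the bound tree is kept for runtime evaluation:
  println(bind(Add(Udf("udf", Attr("cnt")), Lit(1)), Map("cnt" -> Lit(0))))
}

Running the demo prints Lit(1) for the foldable case and the unfolded Add(Udf(udf,Lit(0)),Lit(1)) tree for the UDF case, mirroring how bindingExpr either produces a Literal or defers the expression to runtime.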
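The tests all exercise the same underlying hazard. The rewrite turns a correlated scalar subquery into a left outer join, so outer rows with no matching inner row come back as NULL, while COUNT over zero tuples must be 0; the alwaysTrue flag distinguishes "no matching row" (substitute resultWithZeroTups) from a genuinely NULL aggregate. A toy model of that distinction, again in plain Scala with made-up data (not the tables l and r from the suite):

object CountBugDemo extends App {
  val l = Seq(1, 2, 3)  // outer values of l.a
  val r = Seq(2, 2, 3)  // inner values of r.c

  // Naive rewrite: read COUNT(*) off the joined rows; an unmatched outer row
  // yields NULL (None here) instead of 0 -- the COUNT bug.
  val naive: Map[Int, Option[Int]] =
    l.map(a => a -> (if (r.contains(a)) Some(r.count(_ == a)) else None)).toMap

  // Fixed rewrite: when the "alwaysTrue" side is NULL (no match), substitute
  // the statically computed zero-tuple result, COUNT(*) = 0.
  val fixed: Map[Int, Int] = naive.map { case (a, c) => a -> c.getOrElse(0) }

  println(naive)  // Map(1 -> None, 2 -> Some(2), 3 -> Some(1))
  println(fixed)  // Map(1 -> 0, 2 -> 2, 3 -> 1)
}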