[SPARK-28441][SQL][Python] Fix error when non-foldable expression is used in correlated scalar subquery #25204
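At a high level, the change makes the rule's static-evaluation helpers return `Option[Expression]` instead of `Option[Any]`: anything that is still foldable after binding is folded to a literal as before, while a non-foldable expression (for example a `PythonUDF`) is kept as an expression and evaluated at query runtime. A minimal sketch of that decision, written against Catalyst's public expression API rather than the patch's exact code (`foldOrDefer` is a hypothetical name):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

// Illustrative only: fold what can be folded now, defer everything else to runtime.
def foldOrDefer(e: Expression): Option[Expression] = {
  if (!e.foldable) {
    Some(e) // e.g. a PythonUDF: leave it to be evaluated per row at execution time
  } else {
    // Foldable: evaluate once during optimization and wrap the result in a Literal,
    // treating a null result as "no value", as the patch does.
    Option(e.eval()).map(v => Literal.create(v, e.dataType))
  }
}
```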
Changes from 3 commits
```diff
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.optimizer
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
```
```diff
@@ -316,25 +316,46 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
     newExpression.asInstanceOf[E]
   }
 
+  private def removeAlias(expr: Expression): Expression = expr match {
+    case Alias(c, _) => removeAlias(c)
+    case _ => expr
+  }
+
   /**
    * Statically evaluate an expression containing zero or more placeholders, given a set
-   * of bindings for placeholder values.
+   * of bindings for placeholder values, if the expression is evaluable. If it is not,
+   * bind statically evaluated expression results to an expression.
    */
-  private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
+  private def bindingExpr(
+      expr: Expression,
+      bindings: Map[ExprId, Option[Expression]]): Option[Expression] = {
     val rewrittenExpr = expr transform {
       case r: AttributeReference =>
         bindings(r.exprId) match {
-          case Some(v) => Literal.create(v, r.dataType)
+          case Some(v) => v
          case None => Literal.default(NullType)
        }
     }
-    Option(rewrittenExpr.eval())
+
+    // Removes Alias over given expression, because Alias is not foldable.
+    if (!removeAlias(rewrittenExpr).foldable) {
+      // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
+      // Needs to evaluate them on query runtime.
+      Some(rewrittenExpr)
+    } else {
+      val exprVal = rewrittenExpr.eval()
+      if (exprVal == null) {
+        None
+      } else {
+        Some(Literal.create(exprVal, expr.dataType))
+      }
+    }
   }
 
   /**
    * Statically evaluate an expression containing one or more aggregates on an empty input.
    */
-  private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
+  private def evalAggOnZeroTups(expr: Expression) : Option[Expression] = {
     // AggregateExpressions are Unevaluable, so we need to replace all aggregates
     // in the expression with the value they would return for zero input tuples.
     // Also replace attribute refs (for example, for grouping columns) with NULL.
```
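The branch above hinges on Catalyst's `foldable` flag: an expression is foldable only when it can be evaluated during optimization without any input row (the patch strips `Alias` first because, per its own comment, an alias is not reported as foldable). A small sketch of the distinction, assuming Spark Catalyst is on the classpath; the attribute name is made up:

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

// Built purely from literals: foldable, so eval() works with no input row.
val static = Add(Literal(1), Literal(2))
assert(static.foldable)
assert(static.eval() == 3)

// Depends on a column of the input: not foldable, so evaluation has to wait until runtime.
val col = AttributeReference("a", IntegerType)() // hypothetical attribute
val dynamic = Add(col, Literal(2))
assert(!dynamic.foldable)
```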
```diff
@@ -344,7 +365,20 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
       case _: AttributeReference => Literal.default(NullType)
     }
-    Option(rewrittenExpr.eval())
+
+    // Removes Alias over given expression, because Alias is not foldable.
+    if (!removeAlias(rewrittenExpr).foldable) {
+      // SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
+      // Needs to evaluate them on query runtime.
+      Some(rewrittenExpr)
+    } else {
+      val exprVal = rewrittenExpr.eval()
+      if (exprVal == null) {
+        None
+      } else {
+        Some(Literal.create(exprVal, expr.dataType))
+      }
+    }
   }
 
   /**
```
```diff
@@ -354,27 +388,39 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
    * CheckAnalysis become less restrictive, this method will need to change.
    */
-  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
+  private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
     // Inputs to this method will start with a chain of zero or more SubqueryAlias
     // and Project operators, followed by an optional Filter, followed by an
     // Aggregate. Traverse the operators recursively.
-    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
+    def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Expression]] = lp match {
       case SubqueryAlias(_, child) => evalPlan(child)
       case Filter(condition, child) =>
         val bindings = evalPlan(child)
-        if (bindings.isEmpty) bindings
-        else {
-          val exprResult = evalExpr(condition, bindings).getOrElse(false)
-            .asInstanceOf[Boolean]
-          if (exprResult) bindings else Map.empty
+        if (bindings.isEmpty) {
+          bindings
+        } else {
+          val bindExpr = bindingExpr(condition, bindings)
+            .getOrElse(Literal.create(false, BooleanType))
+
+          if (!bindExpr.foldable) {
+            // We can't evaluate the condition. Evaluate it in query runtime.
+            bindings.map { case (id, expr) =>
+              val newExpr = expr.map(e => If(bindExpr, e, Literal.create(null, e.dataType)))
+              (id, newExpr)
+            }
+          } else {
+            // The bound condition can be evaluated.
+            val exprResult = bindExpr.eval().asInstanceOf[Boolean]
+            if (exprResult) bindings else Map.empty
+          }
         }
 
       case Project(projectList, child) =>
         val bindings = evalPlan(child)
         if (bindings.isEmpty) {
           bindings
         } else {
-          projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
+          projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
         }
 
       case Aggregate(_, aggExprs, _) =>
```
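When the bound filter condition is itself non-foldable, the hunk above no longer tries to decide "keep the bindings or drop them" during optimization; each bound value is wrapped in an `If` so the check happens per row at execution time. A sketch of that wrapping with made-up attribute and value names:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, If, Literal}
import org.apache.spark.sql.types.IntegerType

// A condition that refers to runtime input cannot be folded away...
val x = AttributeReference("x", IntegerType)() // hypothetical correlated column
val cond = GreaterThan(x, Literal(0))
assert(!cond.foldable)

// ...so instead of evaluating it now, defer it: produce the bound value when the
// condition holds and NULL otherwise, mirroring the If(...) built in the patch.
val boundValue = Literal(42) // hypothetical bound result
val deferred = If(cond, boundValue, Literal.create(null, boundValue.dataType))
```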
```diff
@@ -432,6 +478,18 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         sys.error("This line should be unreachable")
   }
 
+  /**
+   * This replaces original expression id used in attributes and aliases in expression.
+   */
+  private def replaceOldExprId(
+      orgExprId: ExprId,
+      newExprId: ExprId): PartialFunction[Expression, Expression] = {
+    case a: AttributeReference if a.exprId == orgExprId =>
+      a.withExprId(newExprId)
+    case a: Alias if a.exprId == orgExprId =>
+      Alias(child = a.child, name = a.name)(exprId = newExprId)
+  }
+
   // Name of generated column used in rewrite below
   val ALWAYS_TRUE_COLNAME = "alwaysTrue"
```
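`replaceOldExprId` is a `PartialFunction` meant to be handed to `transform`/`transformExpressions`, so every attribute or alias still carrying the subquery's original `ExprId` gets the fresh one. A tiny sketch of the same mechanism on a hypothetical attribute:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.types.IntegerType

val out = AttributeReference("agg", IntegerType)() // hypothetical subquery output
val freshId = NamedExpression.newExprId

// Same shape as the patch's helper, restricted to AttributeReference for brevity.
val swapId: PartialFunction[Expression, Expression] = {
  case ref: AttributeReference if ref.exprId == out.exprId => ref.withExprId(freshId)
}

val rewritten = out.transform(swapId)
assert(rewritten.asInstanceOf[AttributeReference].exprId == freshId)
```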
```diff
@@ -465,19 +523,34 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
         val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
           BooleanType)(exprId = alwaysTrueExprId)
 
         val aggValRef = query.output.head
 
         if (havingNode.isEmpty) {
           // CASE 2: Subquery with no HAVING clause
-          // The added Alias column uses expr id of original output.
+          // We replace original expression id with a new one. The added Alias column
+          // must use expr id of original output. If we don't replace old expr id in the
+          // query, the added Project in potential Project-Filter-Project can be removed
+          // by removeProjectBeforeFilter in ColumnPruning.
+          val newExprId = NamedExpression.newExprId
+          val newQuery =
+            query.transformExpressions(replaceOldExprId(origOutput.exprId, newExprId))
+
+          val result = resultWithZeroTups.get
+            .transform(replaceOldExprId(origOutput.exprId, newExprId))
+
+          val newCondition =
+            conditions.map(_.transform(replaceOldExprId(origOutput.exprId, newExprId)))
+
+          val newExpr = Alias(
+            If(IsNull(alwaysTrueRef),
+              result,
+              newQuery.output.head), origOutput.name)(exprId = origOutput.exprId)
+
           Project(
-            currentChild.output :+
-              Alias(
-                If(IsNull(alwaysTrueRef),
-                  Literal.create(resultWithZeroTups.get, origOutput.dataType),
-                  aggValRef), origOutput.name)(exprId = origOutput.exprId),
+            currentChild.output :+ newExpr,
             Join(currentChild,
-              Project(query.output :+ alwaysTrueExpr, query),
-              LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+              Project(newQuery.output :+ alwaysTrueExpr, newQuery),
+              LeftOuter, newCondition.reduceOption(And), JoinHint.NONE))
 
         } else {
           // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
```
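For orientation, the surrounding rule (unchanged by this patch and only partly visible in the context lines) tags the subquery side of the left outer join with an always-true column; after the join, a NULL in that column can only mean the outer row found no subquery match, which is exactly when `resultWithZeroTups` has to be used. A sketch of that marker, under the assumption that `alwaysTrueExpr` is a TRUE literal aliased under `ALWAYS_TRUE_COLNAME`, as in the existing rule:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, IsNull, Literal, NamedExpression}
import org.apache.spark.sql.types.BooleanType

val alwaysTrueExprId = NamedExpression.newExprId
// Appended to the subquery side of the LEFT OUTER join (assumed to match the existing rule).
val alwaysTrueExpr = Alias(Literal.TrueLiteral, "alwaysTrue")(exprId = alwaysTrueExprId)
// Referenced above the join: NULL here means "no subquery row matched this outer row".
val alwaysTrueRef = AttributeReference("alwaysTrue", BooleanType)(exprId = alwaysTrueExprId)
val noMatch = IsNull(alwaysTrueRef)
```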
```diff
@@ -494,21 +567,34 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
             case op => sys.error(s"Unexpected operator $op in corelated subquery")
           }
 
-          // CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
+          // We replace original expression id with a new one. The added Alias column
+          // must use expr id of original output. If we don't replace old expr id in the
+          // query, the added Project in potential Project-Filter-Project can be removed
+          // by removeProjectBeforeFilter in ColumnPruning.
+          val newExprId = NamedExpression.newExprId
+          val newQuery =
+            subqueryRoot.transformExpressions(replaceOldExprId(origOutput.exprId, newExprId))
+
+          val result = resultWithZeroTups.get
+            .transform(replaceOldExprId(origOutput.exprId, newExprId))
+
+          val newCondition =
+            conditions.map(_.transform(replaceOldExprId(origOutput.exprId, newExprId)))
+
+          // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
           //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
           //      ELSE (aggregate value) END AS (original column name)
           val caseExpr = Alias(CaseWhen(Seq(
-            (IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
-            (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
-            aggValRef),
+            (IsNull(alwaysTrueRef), result),
+            (Not(havingNode.get.condition), Literal.create(null, newQuery.output.head.dataType))),
+            newQuery.output.head),
             origOutput.name)(exprId = origOutput.exprId)
 
           Project(
             currentChild.output :+ caseExpr,
             Join(currentChild,
-              Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-              LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+              Project(newQuery.output :+ alwaysTrueExpr, newQuery),
+              LeftOuter, newCondition.reduceOption(And), JoinHint.NONE))
         }
       }
     }
```
Review discussion on the new `removeAlias` helper:

- What if there are several aliases? Shall we use `CleanupAliases` instead?
- We track expressions with the aggregate expressions as the root, so I think any aliases should sit contiguously on top. Using `CleanupAliases` would also be fine; at least we would not need to add a new method.
- Yes, sorry, this one is recursive too, but I think it is good to avoid adding a new method. Thanks.
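Since the thread settles on keeping the helper because it is recursive, a standalone sketch of the recursion handling stacked aliases (the helper is re-declared locally purely for illustration; in the patch it is a private method of the rule):

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, Literal}

// Same shape as the patch's private removeAlias, repeated here only to show the recursion.
def removeAlias(expr: Expression): Expression = expr match {
  case Alias(c, _) => removeAlias(c)
  case _ => expr
}

val nested = Alias(Alias(Literal(1), "inner")(), "outer")()
assert(removeAlias(nested) == Literal(1)) // both alias layers are peeled off
```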