Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -316,25 +316,46 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
newExpression.asInstanceOf[E]
}

private def removeAlias(expr: Expression): Expression = expr match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if there are several aliases? Shall we use CleanupAliases instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We track expressions starting from the aggregate expressions as the root, so I think any aliases should be contiguous at the top. Using CleanupAliases is also good; at least we wouldn't need to add a new method.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sorry, this is recursive too, but I think it is good to avoid adding a new method. Thanks.

case Alias(c, _) => removeAlias(c)
case _ => expr
}

/**
* Statically evaluate an expression containing zero or more placeholders, given a set
* of bindings for placeholder values.
* of bindings for placeholder values, if the expression is evaluable. If it is not,
* bind statically evaluated expression results to an expression.
*/
private def evalExpr(expr: Expression, bindings: Map[ExprId, Option[Any]]) : Option[Any] = {
private def bindingExpr(
expr: Expression,
bindings: Map[ExprId, Option[Expression]]): Option[Expression] = {
val rewrittenExpr = expr transform {
case r: AttributeReference =>
bindings(r.exprId) match {
case Some(v) => Literal.create(v, r.dataType)
case Some(v) => v
case None => Literal.default(NullType)
}
}
Option(rewrittenExpr.eval())

// Removes Alias over given expression, because Alias is not foldable.
if (!removeAlias(rewrittenExpr).foldable) {
// SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
// Needs to evaluate them on query runtime.
Some(rewrittenExpr)
} else {
val exprVal = rewrittenExpr.eval()
if (exprVal == null) {
None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why we need to return None here instead of a null literal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it uses None to make checking bindings easier.

Alternatively, to use a null literal, Option[Expression] can be changed to Expression in methods like evalSubqueryOnZeroTups and evalPlan. Then we would check bindings against a null literal instead of None. The good thing is that we could write Literal.create(rewrittenExpr.eval(), expr.dataType) instead of checking for null. It looks like just a matter of preference.

} else {
Some(Literal.create(exprVal, expr.dataType))
}
}
}

/**
* Statically evaluate an expression containing one or more aggregates on an empty input.
*/
private def evalAggOnZeroTups(expr: Expression) : Option[Any] = {
private def evalAggOnZeroTups(expr: Expression) : Option[Expression] = {
// AggregateExpressions are Unevaluable, so we need to replace all aggregates
// in the expression with the value they would return for zero input tuples.
// Also replace attribute refs (for example, for grouping columns) with NULL.
Expand All @@ -344,7 +365,20 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

case _: AttributeReference => Literal.default(NullType)
}
Option(rewrittenExpr.eval())

// Removes Alias over given expression, because Alias is not foldable.
if (!removeAlias(rewrittenExpr).foldable) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like we can move the following code into a common method?

// SPARK-28441: Some expressions, like PythonUDF, can't be statically evaluated.
// Needs to evaluate them on query runtime.
Some(rewrittenExpr)
} else {
val exprVal = rewrittenExpr.eval()
if (exprVal == null) {
None
} else {
Some(Literal.create(exprVal, expr.dataType))
}
}
}

/**
Expand All @@ -354,27 +388,39 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* [[org.apache.spark.sql.catalyst.analysis.CheckAnalysis]]. If the checks in
* CheckAnalysis become less restrictive, this method will need to change.
*/
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Any] = {
private def evalSubqueryOnZeroTups(plan: LogicalPlan) : Option[Expression] = {
// Inputs to this method will start with a chain of zero or more SubqueryAlias
// and Project operators, followed by an optional Filter, followed by an
// Aggregate. Traverse the operators recursively.
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match {
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Expression]] = lp match {
case SubqueryAlias(_, child) => evalPlan(child)
case Filter(condition, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) bindings
else {
val exprResult = evalExpr(condition, bindings).getOrElse(false)
.asInstanceOf[Boolean]
if (exprResult) bindings else Map.empty
if (bindings.isEmpty) {
bindings
} else {
val bindExpr = bindingExpr(condition, bindings)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: bindCondition looks better.

.getOrElse(Literal.create(false, BooleanType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For filter condition, null is the same as false. This is one place that makes me think bindingExpr should return Option[Expression].

If this is the only place, I think it's simpler to always return an expression, and handle null specially here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I may try this way tomorrow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this works. Looks good as it's simple.


if (!bindExpr.foldable) {
// We can't evaluate the condition. Evaluate it in query runtime.
bindings.map { case (id, expr) =>
val newExpr = expr.map(e => If(bindExpr, e, Literal.create(null, e.dataType)))
(id, newExpr)
}
} else {
// The bound condition can be evaluated.
val exprResult = bindExpr.eval().asInstanceOf[Boolean]
if (exprResult) bindings else Map.empty
}
}

case Project(projectList, child) =>
val bindings = evalPlan(child)
if (bindings.isEmpty) {
bindings
} else {
projectList.map(ne => (ne.exprId, evalExpr(ne, bindings))).toMap
projectList.map(ne => (ne.exprId, bindingExpr(ne, bindings))).toMap
}

case Aggregate(_, aggExprs, _) =>
Expand Down Expand Up @@ -432,6 +478,18 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
sys.error("This line should be unreachable")
}

/**
* This replaces original expression id used in attributes and aliases in expression.
*/
private def replaceOldExprId(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can remove this.

orgExprId: ExprId,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

orgExprId -> oldExprId ?

newExprId: ExprId): PartialFunction[Expression, Expression] = {
case a: AttributeReference if a.exprId == orgExprId =>
a.withExprId(newExprId)
case a: Alias if a.exprId == orgExprId =>
Alias(child = a.child, name = a.name)(exprId = newExprId)
}

// Name of generated column used in rewrite below
val ALWAYS_TRUE_COLNAME = "alwaysTrue"

Expand Down Expand Up @@ -465,19 +523,34 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
BooleanType)(exprId = alwaysTrueExprId)

val aggValRef = query.output.head

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
// The added Alias column uses expr id of original output.

// We replace original expression id with a new one. The added Alias column
// must use expr id of original output. If we don't replace old expr id in the
// query, the added Project in potential Project-Filter-Project can be removed
// by removeProjectBeforeFilter in ColumnPruning.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we fix removeProjectBeforeFilter to only remove attribute-only projects?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth trying, though right now I'm not sure whether anything else will be affected.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried locally. The added subquery tests pass. We can see if Jenkins passes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. Seems fine. Jenkins passes.

val newExprId = NamedExpression.newExprId
val newQuery =
query.transformExpressions(replaceOldExprId(origOutput.exprId, newExprId))

val result = resultWithZeroTups.get
.transform(replaceOldExprId(origOutput.exprId, newExprId))

val newCondition =
conditions.map(_.transform(replaceOldExprId(origOutput.exprId, newExprId)))

val newExpr = Alias(
If(IsNull(alwaysTrueRef),
result,
newQuery.output.head), origOutput.name)(exprId = origOutput.exprId)

Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
Literal.create(resultWithZeroTups.get, origOutput.dataType),
aggValRef), origOutput.name)(exprId = origOutput.exprId),
currentChild.output :+ newExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
Project(newQuery.output :+ alwaysTrueExpr, newQuery),
LeftOuter, newCondition.reduceOption(And), JoinHint.NONE))

} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
Expand All @@ -494,21 +567,34 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}

// CASE WHEN alwayTrue IS NULL THEN resultOnZeroTups
// We replace original expression id with a new one. The added Alias column
// must use expr id of original output. If we don't replace old expr id in the
// query, the added Project in potential Project-Filter-Project can be removed
// by removeProjectBeforeFilter in ColumnPruning.
val newExprId = NamedExpression.newExprId
val newQuery =
subqueryRoot.transformExpressions(replaceOldExprId(origOutput.exprId, newExprId))

val result = resultWithZeroTups.get
.transform(replaceOldExprId(origOutput.exprId, newExprId))

val newCondition =
conditions.map(_.transform(replaceOldExprId(origOutput.exprId, newExprId)))

// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), Literal.create(resultWithZeroTups.get, origOutput.dataType)),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
(IsNull(alwaysTrueRef), result),
(Not(havingNode.get.condition), Literal.create(null, newQuery.output.head.dataType))),
newQuery.output.head),
origOutput.name)(exprId = origOutput.exprId)

Project(
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

Project(newQuery.output :+ alwaysTrueExpr, newQuery),
LeftOuter, newCondition.reduceOption(And), JoinHint.NONE))
}
}
}
Expand Down
Loading