-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-42500][SQL] ConstantPropagation support more cases #40268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
dc04923
0bb53d8
1387247
cabb083
0244034
b9f3fbb
9acb27e
ecd650a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import scala.collection.immutable.HashSet | ||
| import scala.collection.mutable | ||
| import scala.collection.mutable.{ArrayBuffer, Stack} | ||
| import scala.util.control.NonFatal | ||
|
|
||
|
|
@@ -112,16 +113,13 @@ object ConstantFolding extends Rule[LogicalPlan] { | |
| object ConstantPropagation extends Rule[LogicalPlan] { | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( | ||
| _.containsAllPatterns(LITERAL, FILTER), ruleId) { | ||
| case f: Filter => | ||
| val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true) | ||
| if (newCondition.isDefined) { | ||
| f.copy(condition = newCondition.get) | ||
| } else { | ||
| f | ||
| } | ||
| case f: Filter => f.mapExpressions(traverse(_, replaceChildren = true, nullIsFalse = true)._1) | ||
| } | ||
|
|
||
| type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] | ||
| // The keys are always canonicalized `AttributeReference`s, but it is easier to use `Expression` | ||
| // type keys here instead of casting `AttributeReference.canonicalized` to `AttributeReference` at | ||
| // the calling sites. | ||
| type EqualityPredicates = mutable.Map[Expression, (Literal, BinaryComparison)] | ||
|
|
||
| /** | ||
| * Traverse a condition as a tree and replace attributes with constant values. | ||
|
|
@@ -138,56 +136,52 @@ object ConstantPropagation extends Rule[LogicalPlan] { | |
| * case of `WHERE e`, null result of expression `e` means the same as if it | ||
| * resulted false | ||
| * @return A tuple including: | ||
| * 1. Option[Expression]: optional changed condition after traversal | ||
| * 1. Expression: optional changed condition after traversal | ||
|
||
| * 2. EqualityPredicates: propagated mapping of attribute => constant | ||
| */ | ||
| private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) | ||
| : (Option[Expression], EqualityPredicates) = | ||
| private def traverse( | ||
| condition: Expression, | ||
| replaceChildren: Boolean, | ||
| nullIsFalse: Boolean): (Expression, EqualityPredicates) = | ||
| condition match { | ||
| case e @ EqualTo(left: AttributeReference, right: Literal) | ||
| if safeToReplace(left, nullIsFalse) => | ||
| (None, Seq(((left, right), e))) | ||
| e -> mutable.Map(left.canonicalized -> (right, e)) | ||
| case e @ EqualTo(left: Literal, right: AttributeReference) | ||
| if safeToReplace(right, nullIsFalse) => | ||
| (None, Seq(((right, left), e))) | ||
| e -> mutable.Map(right.canonicalized -> (left, e)) | ||
| case e @ EqualNullSafe(left: AttributeReference, right: Literal) | ||
| if safeToReplace(left, nullIsFalse) => | ||
| (None, Seq(((left, right), e))) | ||
| e -> mutable.Map(left.canonicalized -> (right, e)) | ||
| case e @ EqualNullSafe(left: Literal, right: AttributeReference) | ||
| if safeToReplace(right, nullIsFalse) => | ||
| (None, Seq(((right, left), e))) | ||
| case a: And => | ||
| val (newLeft, equalityPredicatesLeft) = | ||
| traverse(a.left, replaceChildren = false, nullIsFalse) | ||
| e -> mutable.Map(right.canonicalized -> (left, e)) | ||
| case a @ And(left, right) => | ||
| val (newLeft, equalityPredicates) = | ||
| traverse(left, replaceChildren = false, nullIsFalse) | ||
| val (newRight, equalityPredicatesRight) = | ||
| traverse(a.right, replaceChildren = false, nullIsFalse) | ||
| val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight | ||
| val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { | ||
| Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), | ||
| replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) | ||
| traverse(right, replaceChildren = false, nullIsFalse) | ||
| // We could recognize when conflicting constants are coming from the left and right sides | ||
| // and immediately shortcut the `And` expression to `Literal.FalseLiteral`, but that case is | ||
| // not so common and actually it is the job of `ConstantFolding` and `BooleanSimplification` | ||
| // rules to deal with those optimizations. | ||
| equalityPredicates ++= equalityPredicatesRight | ||
| val newAnd = a.withNewChildren(if (equalityPredicates.nonEmpty && replaceChildren) { | ||
| val replacedNewLeft = replaceConstants(newLeft, equalityPredicates) | ||
| val replacedNewRight = replaceConstants(newRight, equalityPredicates) | ||
| Seq(replacedNewLeft, replacedNewRight) | ||
| } else { | ||
| if (newLeft.isDefined || newRight.isDefined) { | ||
| Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| (newSelf, equalityPredicates) | ||
| Seq(newLeft, newRight) | ||
| }) | ||
| newAnd -> equalityPredicates | ||
| case o: Or => | ||
| // Ignore the EqualityPredicates from children since they are only propagated through And. | ||
| val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse) | ||
| val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse) | ||
| val newSelf = if (newLeft.isDefined || newRight.isDefined) { | ||
| Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) | ||
| } else { | ||
| None | ||
| } | ||
| (newSelf, Seq.empty) | ||
| o.mapChildren(traverse(_, replaceChildren = true, nullIsFalse)._1) -> mutable.Map.empty | ||
wangyum marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| case n: Not => | ||
| // Ignore the EqualityPredicates from children since they are only propagated through And. | ||
| val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) | ||
| (newChild.map(Not), Seq.empty) | ||
| case _ => (None, Seq.empty) | ||
| n.mapChildren(traverse(_, replaceChildren = true, nullIsFalse = false)._1) -> | ||
wangyum marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| mutable.Map.empty | ||
| case o => o -> mutable.Map.empty | ||
| } | ||
|
|
||
| // We need to take into account if an attribute is nullable and the context of the conjunctive | ||
|
|
@@ -198,16 +192,15 @@ object ConstantPropagation extends Rule[LogicalPlan] { | |
| private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = | ||
| !ar.nullable || nullIsFalse | ||
|
|
||
| private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) | ||
| : Expression = { | ||
| val constantsMap = AttributeMap(equalityPredicates.map(_._1)) | ||
| val predicates = equalityPredicates.map(_._2).toSet | ||
| def replaceConstants0(expression: Expression) = expression transform { | ||
| case a: AttributeReference => constantsMap.getOrElse(a, a) | ||
| } | ||
| private def replaceConstants( | ||
| condition: Expression, | ||
| equalityPredicates: EqualityPredicates): Expression = { | ||
| val predicates = equalityPredicates.values.map(_._2).toSet | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we keep the val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason for changing The main point of that map is that we store only one |
||
| condition transform { | ||
| case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) | ||
| case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) | ||
| case b: BinaryComparison if !predicates.contains(b) => b transform { | ||
| case a: AttributeReference if equalityPredicates.contains(a.canonicalized) => | ||
| equalityPredicates(a.canonicalized)._1 | ||
|
||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
AttributeReferenceis enough.Can we use
AttributeMap? In order to avoid the use ofx.canonicalizedlater:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I deliberately used mutable map here to improve their addition (
equalityPredicates ++= equalityPredicatesRight) later.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a small improvement in cabb083