Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, Stack}
import scala.util.control.NonFatal

Expand Down Expand Up @@ -112,16 +113,13 @@ object ConstantFolding extends Rule[LogicalPlan] {
object ConstantPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsAllPatterns(LITERAL, FILTER), ruleId) {
case f: Filter =>
val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
if (newCondition.isDefined) {
f.copy(condition = newCondition.get)
} else {
f
}
case f: Filter => f.mapExpressions(traverse(_, replaceChildren = true, nullIsFalse = true)._1)
}

type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)]
// The keys are always canonicalized `AttributeReference`s, but it is easier to use `Expression`
// type keys here instead of casting `AttributeReference.canonicalized` to `AttributeReference` at
// the calling sites.
type EqualityPredicates = mutable.Map[Expression, (Literal, BinaryComparison)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think AttributeReference is enough.

Can we use AttributeMap, in order to avoid calling x.canonicalized later?

type EqualityPredicates = AttributeMap[(Literal, BinaryComparison)]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I deliberately used a mutable map here to make merging the two maps (equalityPredicates ++= equalityPredicatesRight) more efficient later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a small improvement in cabb083


/**
* Traverse a condition as a tree and replace attributes with constant values.
Expand All @@ -138,56 +136,52 @@ object ConstantPropagation extends Rule[LogicalPlan] {
* case of `WHERE e`, null result of expression `e` means the same as if it
* resulted false
* @return A tuple including:
* 1. Option[Expression]: optional changed condition after traversal
* 1. Expression: the (possibly rewritten) condition after traversal
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the optional?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I forgot to update that. Fixed in 0244034

* 2. EqualityPredicates: propagated mapping of attribute => constant
*/
private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
: (Option[Expression], EqualityPredicates) =
private def traverse(
condition: Expression,
replaceChildren: Boolean,
nullIsFalse: Boolean): (Expression, EqualityPredicates) =
condition match {
case e @ EqualTo(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
e -> mutable.Map(left.canonicalized -> (right, e))
case e @ EqualTo(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
e -> mutable.Map(right.canonicalized -> (left, e))
case e @ EqualNullSafe(left: AttributeReference, right: Literal)
if safeToReplace(left, nullIsFalse) =>
(None, Seq(((left, right), e)))
e -> mutable.Map(left.canonicalized -> (right, e))
case e @ EqualNullSafe(left: Literal, right: AttributeReference)
if safeToReplace(right, nullIsFalse) =>
(None, Seq(((right, left), e)))
case a: And =>
val (newLeft, equalityPredicatesLeft) =
traverse(a.left, replaceChildren = false, nullIsFalse)
e -> mutable.Map(right.canonicalized -> (left, e))
case a @ And(left, right) =>
val (newLeft, equalityPredicates) =
traverse(left, replaceChildren = false, nullIsFalse)
val (newRight, equalityPredicatesRight) =
traverse(a.right, replaceChildren = false, nullIsFalse)
val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight
val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates),
replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
traverse(right, replaceChildren = false, nullIsFalse)
// We could recognize when conflicting constants are coming from the left and right sides
// and immediately shortcut the `And` expression to `Literal.FalseLiteral`, but that case is
// not so common and actually it is the job of `ConstantFolding` and `BooleanSimplification`
// rules to deal with those optimizations.
equalityPredicates ++= equalityPredicatesRight
val newAnd = a.withNewChildren(if (equalityPredicates.nonEmpty && replaceChildren) {
val replacedNewLeft = replaceConstants(newLeft, equalityPredicates)
val replacedNewRight = replaceConstants(newRight, equalityPredicates)
Seq(replacedNewLeft, replacedNewRight)
} else {
if (newLeft.isDefined || newRight.isDefined) {
Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
} else {
None
}
}
(newSelf, equalityPredicates)
Seq(newLeft, newRight)
})
newAnd -> equalityPredicates
case o: Or =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newLeft, _) = traverse(o.left, replaceChildren = true, nullIsFalse)
val (newRight, _) = traverse(o.right, replaceChildren = true, nullIsFalse)
val newSelf = if (newLeft.isDefined || newRight.isDefined) {
Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right))))
} else {
None
}
(newSelf, Seq.empty)
o.mapChildren(traverse(_, replaceChildren = true, nullIsFalse)._1) -> mutable.Map.empty
case n: Not =>
// Ignore the EqualityPredicates from children since they are only propagated through And.
val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false)
(newChild.map(Not), Seq.empty)
case _ => (None, Seq.empty)
n.mapChildren(traverse(_, replaceChildren = true, nullIsFalse = false)._1) ->
mutable.Map.empty
case o => o -> mutable.Map.empty
}

// We need to take into account if an attribute is nullable and the context of the conjunctive
Expand All @@ -198,16 +192,15 @@ object ConstantPropagation extends Rule[LogicalPlan] {
// An equality on `ar` qualifies when the caller treats null as false or `ar` can never be null.
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
  nullIsFalse || !ar.nullable

private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates)
: Expression = {
val constantsMap = AttributeMap(equalityPredicates.map(_._1))
val predicates = equalityPredicates.map(_._2).toSet
def replaceConstants0(expression: Expression) = expression transform {
case a: AttributeReference => constantsMap.getOrElse(a, a)
}
private def replaceConstants(
condition: Expression,
equalityPredicates: EqualityPredicates): Expression = {
val predicates = equalityPredicates.values.map(_._2).toSet
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep the constantsMap?

val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason for changing EqualityPredicates to mutable.Map earlier was to avoid building a map here.

The main point of that map is that we store only one Literal (and its original BinaryComparison) per attribute key. So if there are two or more conflicting EqualTo predicates, replaceConstants() keeps only one of them in its original form and rewrites the other conflicting ones. E.g. for a = 1 AND a = 2 we store only a -> (2, a = 2) in the map and rewrite the expression to 2 = 1 AND a = 2.

condition transform {
case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e)
case b: BinaryComparison if !predicates.contains(b) => b transform {
case a: AttributeReference if equalityPredicates.contains(a.canonicalized) =>
equalityPredicates(a.canonicalized)._1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

case a: AttributeReference => constantsMap.getOrElse(a, a)

?

Copy link
Contributor Author

@peter-toth peter-toth Mar 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kept the current mutable.Map based EqualityPredicates, so I changed it to case a: AttributeReference => equalityPredicates.get(a.canonicalized).map(_._1).getOrElse(a).

But if we decide to use AttributeMap then I can change it to the suggested.

}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest {
columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3)))

val correctAnswer = testRelation
.select(columnA)
.where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze
.select(columnA, columnB)
.where(Literal.FalseLiteral)
.select(columnA).analyze

comparePlans(Optimize.execute(query.analyze), correctAnswer)
}
Expand All @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest {
.analyze
comparePlans(Optimize.execute(query2), correctAnswer2)
}

// Each case runs a filter through `Optimize` and compares the result with a hand-written plan.
// NOTE(review): the rules in the `Optimize` batch are declared outside this view; the expected
// plans suggest it also folds constants and simplifies booleans -- confirm against the suite.
test("SPARK-42500: ConstantPropagation supports more cases") {
// A constant from an equality is substituted into a non-equality comparison (`>`):
// `columnA + 2` is expected to end up as the literal 3.
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze),
testRelation.where(columnA === 1 && columnB > 3).analyze)

// Two conflicting equalities on the same attribute are expected to collapse to `false`.
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

// An equality whose other side references the same attribute is also expected to become `false`.
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze),
testRelation.where(Literal.FalseLiteral).analyze)

// The constant for `columnB` from the right side of the `And` reaches the `Or` on the left:
// the `columnB === 2` branch is expected to disappear, leaving only `columnA === 1`.
comparePlans(
Optimize.execute(
testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze),
testRelation.where(columnA === 1 && columnB === 1).analyze)

// Duplicate, non-conflicting equalities are expected to reduce to a single predicate.
comparePlans(
Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze),
testRelation.where(columnA === 1).analyze)

// Under `Not`, the expected plan keeps `columnA` unreplaced in both operands; only the negation
// is distributed over the conjunction.
comparePlans(
Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze),
testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze)
}
}