diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8cb560199c06..7b44539929c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -122,8 +122,6 @@ object ConstantPropagation extends Rule[LogicalPlan] { } } - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - /** * Traverse a condition as a tree and replace attributes with constant values. * - On matching [[And]], recursively traverse each children and get propagated mappings. @@ -140,23 +138,23 @@ object ConstantPropagation extends Rule[LogicalPlan] { * resulted false * @return A tuple including: * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant + * 2. AttributeMap: propagated mapping of attribute => constant */ private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean) - : (Option[Expression], EqualityPredicates) = + : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) = condition match { case e @ EqualTo(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualTo(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case e @ EqualNullSafe(left: AttributeReference, right: Literal) if safeToReplace(left, nullIsFalse) => - (None, Seq(((left, right), e))) + (None, AttributeMap(Map(left -> (right, e)))) case e @ EqualNullSafe(left: Literal, right: AttributeReference) if safeToReplace(right, nullIsFalse) => - (None, Seq(((right, left), e))) + (None, AttributeMap(Map(right -> (left, e)))) case a: And => val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false, nullIsFalse) @@ -183,12 +181,12 @@ object ConstantPropagation extends Rule[LogicalPlan] { } else { None } - (newSelf, Seq.empty) + (newSelf, AttributeMap.empty) case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newChild, _) = traverse(n.child, replaceChildren = true, nullIsFalse = false) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) + (newChild.map(Not), AttributeMap.empty) + case _ => (None, AttributeMap.empty) } // We need to take into account if an attribute is nullable and the context of the conjunctive @@ -199,16 +197,15 @@ object ConstantPropagation extends Rule[LogicalPlan] { private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) = !ar.nullable || nullIsFalse - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } + private def replaceConstants( + condition: Expression, + equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = { + val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit }) + val predicates = equalityPredicates.values.map(_._2).toSet condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) + case b: BinaryComparison if !predicates.contains(b) => b transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index f5f1455f9461..106af71a9d65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -159,8 +159,9 @@ class ConstantPropagationSuite extends PlanTest { columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + .select(columnA, columnB) + .where(Literal.FalseLiteral) + .select(columnA).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,4 +187,31 @@ class ConstantPropagationSuite extends PlanTest { .analyze comparePlans(Optimize.execute(query2), correctAnswer2) } + + test("SPARK-42500: ConstantPropagation supports more cases") { + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnB > columnA + 2).analyze), + testRelation.where(columnA === 1 && columnB > 3).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === columnA + 2).analyze), + testRelation.where(Literal.FalseLiteral).analyze) + + comparePlans( + Optimize.execute( + testRelation.where((columnA === 1 || columnB === 2) && columnB === 1).analyze), + testRelation.where(columnA === 1 && columnB === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(columnA === 1 && columnA === 1).analyze), + testRelation.where(columnA === 1).analyze) + + comparePlans( + Optimize.execute(testRelation.where(Not(columnA === 1 && columnA === columnA + 2)).analyze), + testRelation.where(Not(columnA === 1) || Not(columnA === columnA + 2)).analyze) + } }