From dd09f5166c1ad42176aa8cdf63ecfadab51b871a Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 28 Apr 2019 10:54:02 +0200 Subject: [PATCH 01/12] [SPARK-27604][SQL] Enhance constant propagation to constraint propagation --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 95 +----- .../catalyst/plans/logical/LogicalPlan.scala | 5 +- .../plans/logical/QueryPlanConstraints.scala | 240 +++++++++++++++ .../optimizer/ConstantPropagationSuite.scala | 289 +++++++++++++++++- .../InferFiltersFromConstraintsSuite.scala | 2 +- .../spark/sql/catalyst/plans/PlanTest.scala | 23 +- 7 files changed, 551 insertions(+), 105 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f32f2c7986dc..741e6ee71b02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -80,7 +80,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // Constant folding and strength reduction TransposeWindow, NullPropagation, - ConstantPropagation, + ConstraintPropagation, FoldablePropagation, OptimizeIn, ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 39709529c00d..8108d6fc9994 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -55,100 +55,24 @@ object ConstantFolding extends Rule[LogicalPlan] { } /** - * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding - * value in conjunctive [[Expression Expressions]] + * Substitutes expressions which can be statically narrowed by constrains. * eg. * {{{ - * SELECT * FROM table WHERE i = 5 AND j = i + 3 - * ==> SELECT * FROM table WHERE i = 5 AND j = 8 + * SELECT * FROM table WHERE i = 5 AND j = i + 3 => SELECT * FROM table WHERE i = 5 AND j = 8 + * SELECT * FROM table WHERE i <= 5 AND i = 5 => SELECT * FROM table WHERE i = 5 + * SELECT * FROM table WHERE i < j AND ... AND i > j => SELECT * FROM table WHERE false * }}} - * - * Approach used: - * - Populate a mapping of attribute => constant value by looking at all the equals predicates - * - Using this mapping, replace occurrence of the attributes with the corresponding constant values - * in the AND node. */ -object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { +object ConstraintPropagation extends Rule[LogicalPlan] with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => - val (newCondition, _) = traverse(f.condition, replaceChildren = true) - if (newCondition.isDefined) { - f.copy(condition = newCondition.get) - } else { + val (newCondition, _) = simplifyWithConstraints(f.condition) + if (newCondition fastEquals f.condition) { f + } else { + f.copy(condition = newCondition) } } - - type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] - - /** - * Traverse a condition as a tree and replace attributes with constant values. - * - On matching [[And]], recursively traverse each children and get propagated mappings. - * If the current node is not child of another [[And]], replace all occurrences of the - * attributes with the corresponding constant values. - * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping - * of attribute => constant. - * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. - * - Otherwise, stop traversal and propagate empty mapping. - * @param condition condition to be traversed - * @param replaceChildren whether to replace attributes with constant values in children - * @return A tuple including: - * 1. Option[Expression]: optional changed condition after traversal - * 2. EqualityPredicates: propagated mapping of attribute => constant - */ - private def traverse(condition: Expression, replaceChildren: Boolean) - : (Option[Expression], EqualityPredicates) = - condition match { - case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e))) - case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e))) - case e @ EqualNullSafe(left: AttributeReference, right: Literal) => - (None, Seq(((left, right), e))) - case e @ EqualNullSafe(left: Literal, right: AttributeReference) => - (None, Seq(((right, left), e))) - case a: And => - val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false) - val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false) - val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight - val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { - Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), - replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) - } else { - if (newLeft.isDefined || newRight.isDefined) { - Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) - } else { - None - } - } - (newSelf, equalityPredicates) - case o: Or => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = traverse(o.left, replaceChildren = true) - val (newRight, _) = traverse(o.right, replaceChildren = true) - val newSelf = if (newLeft.isDefined || newRight.isDefined) { - Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) - } else { - None - } - (newSelf, Seq.empty) - case n: Not => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = traverse(n.child, replaceChildren = true) - (newChild.map(Not), Seq.empty) - case _ => (None, Seq.empty) - } - - private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) - : Expression = { - val constantsMap = AttributeMap(equalityPredicates.map(_._1)) - val predicates = equalityPredicates.map(_._2).toSet - def replaceConstants0(expression: Expression) = expression transform { - case a: AttributeReference => constantsMap.getOrElse(a, a) - } - condition transform { - case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) - case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) - } - } } /** @@ -389,6 +313,7 @@ object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { case q: LogicalPlan => q transformExpressionsUp { // True with equality case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualNullSafe b if a.foldable || b.foldable => EqualTo(a, b) case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index db272b3d415e..9bea4cfac3c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -173,14 +173,15 @@ abstract class UnaryNode extends LogicalPlan { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => - allConstraints += EqualNullSafe(a.toAttribute, l) + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) - allConstraints += EqualNullSafe(e, a.toAttribute) + allConstraints += + (if (e.foldable) EqualTo(e, a.toAttribute) else EqualNullSafe(e, a.toAttribute)) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index cc352c59dff8..08687bd855c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -122,4 +122,244 @@ trait ConstraintHelper { case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } + + /** + * Traverse a condition as a tree and simplify expressions with constraints. + * - On matching [[And]], recursively traverse both children, simplify child expressions with + * propagated constraints from sibling and propagate up union of constraints. + * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]], + * [[GreaterThan]] or [[GreaterThanOrEqual]] propagate the constraint. + * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints. + * - Otherwise, stop traversal and propagate no constraints. + * @param expression expression to be traversed + * @return A tuple including: + * 1. Expression: optionally changed condition after traversal + * 2. Seq[Expression]: propagated constraints + */ + def simplifyWithConstraints(expression: Expression): (Expression, Seq[Expression]) = + expression match { + case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe | _: GreaterThan | + _: GreaterThanOrEqual ) + if e.deterministic => (e, Seq(normalize(e))) + case a @ And(left, right) => + val (newLeft, leftConstraints) = simplifyWithConstraints(left) + val simplifiedRight = simplify(right, leftConstraints) + val (simplifiedNewRight, rightConstraints) = simplifyWithConstraints(simplifiedRight) + val simplifiedNewLeft = simplify(newLeft, rightConstraints) + val newAnd = if ((simplifiedNewLeft fastEquals left) && + (simplifiedNewRight fastEquals right)) { + a + } else { + And(simplifiedNewLeft, simplifiedNewRight) + } + (newAnd, leftConstraints ++ rightConstraints) + case o @ Or(left, right) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = simplifyWithConstraints(left) + val (newRight, _) = simplifyWithConstraints(right) + val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { + o + } else { + Or(newLeft, newRight) + } + + (newOr, Seq.empty) + case n @ Not(child) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = simplifyWithConstraints(child) + val newNot = if (newChild fastEquals child) { + n + } else { + Not(newChild) + } + (newNot, Seq.empty) + case _ => (expression, Seq.empty) + } + + private def normalize(expression: Expression) = expression transform { + case GreaterThan(x, y) => LessThan(y, x) + case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) + } + + private def simplify(expression: Expression, constraints: Seq[Expression]): Expression = + constraints.foldLeft(normalize(expression))((e, constraint) => simplify(e, constraint)) + + private def planEqual(x: Expression, y: Expression) = + !x.foldable && !y.foldable && x.canonicalized == y.canonicalized + + private def valueEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && EqualTo(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThan(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThan(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThanOrEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def simplify(expression: Expression, constraint: Expression): Expression = + constraint match { + case a LessThan b => expression transformUp { + case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => + Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => + Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => + Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a LessThanOrEqual b => expression transformUp { + case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => + EqualTo(c, d) + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => + EqualTo(c, d) + + case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a EqualTo b => + if (b.foldable) { + expression transformUp { case c if planEqual(a, c) => b } + } else if (a.foldable) { + expression transformUp { case c if planEqual(b, c) => a } + } else { + expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(a, d) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(b, d) => + Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => + Literal.TrueLiteral + + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => + Literal.TrueLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => + Literal.TrueLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => + Literal.TrueLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => + Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => + Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + } + case a EqualNullSafe b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => + Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => + EqualTo(c, d) + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => + Literal.TrueLiteral + } + case _ => expression + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 94174eec8fd0..2dce72505445 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,12 +36,28 @@ class ConstantPropagationSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantPropagation", FixedPoint(10), - ConstantPropagation, + ConstraintPropagation, ConstantFolding, - BooleanSimplification) :: Nil + BooleanSimplification, + SimplifyBinaryComparison, + PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean, 'y.boolean) + + val data = { + val intElements = Seq(null, 1, 2, 3) + val booleanElements = Seq(null, true, false) + for { + a <- intElements + b <- intElements + c <- intElements + x <- booleanElements + y <- booleanElements + } yield (a, b, c, x, y) + } + + val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) private val columnA = 'a private val columnB = 'b @@ -106,14 +123,12 @@ class ConstantPropagationSuite extends PlanTest { test("equality predicates outside a `OR` can be propagated within a `OR`") { val query = testRelation - .select(columnA) .where( columnA === Literal(2) && (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) .analyze val correctAnswer = testRelation - .select(columnA) .where( columnA === Literal(2) && (Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9))) @@ -153,14 +168,270 @@ class ConstantPropagationSuite extends PlanTest { test("conflicting equality predicates") { val query = testRelation - .select(columnA) .where( columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) - val correctAnswer = testRelation - .select(columnA) - .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze + val correctAnswer = testRelation.analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } + + private def testPropagation( + input: Expression, + expectEmptyRelation: Boolean, + expectedConstraints: Seq[Expression] = Seq.empty) = { + val originalQuery = testRelationWithData.where(input).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = if (expectEmptyRelation) { + testRelation + } else { + testRelationWithData.where(expectedConstraints.reduce(And)).analyze + } + comparePlans(optimized, correctAnswer) + } + + test("Enhanced constraint propagation") { + testPropagation('a < 2 && Literal(2) < 'a, true) + testPropagation('a < 2 && Literal(2) <= 'a, true) + testPropagation('a < 2 && Literal(2) === 'a, true) + testPropagation('a < 2 && Literal(2) <=> 'a, true) + testPropagation('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) + testPropagation('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a <= 2 && Literal(2) < 'a, true) + testPropagation('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testPropagation('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a <= 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a === 2 && Literal(2) < 'a, true) + testPropagation('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) > 'a, true) + testPropagation('a <=> 2 && Literal(2) < 'a, true) + testPropagation('a <=> 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(2) > 'a, true) + testPropagation('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testPropagation('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(2) > 'a, true) + testPropagation('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a > 2 && Literal(2) === 'a, true) + testPropagation('a > 2 && Literal(2) <=> 'a, true) + testPropagation('a > 2 && Literal(2) >= 'a, true) + testPropagation('a > 2 && Literal(2) > 'a, true) + + testPropagation('a < 2 && Literal(3) < 'a, true) + testPropagation('a < 2 && Literal(3) <= 'a, true) + testPropagation('a < 2 && Literal(3) === 'a, true) + testPropagation('a < 2 && Literal(3) <=> 'a, true) + testPropagation('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) + testPropagation('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) + testPropagation('a <= 2 && Literal(3) < 'a, true) + testPropagation('a <= 2 && Literal(3) <= 'a, true) + testPropagation('a <= 2 && Literal(3) === 'a, true) + testPropagation('a <= 2 && Literal(3) <=> 'a, true) + testPropagation('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) + testPropagation('a === 2 && Literal(3) < 'a, true) + testPropagation('a === 2 && Literal(3) <= 'a, true) + testPropagation('a === 2 && Literal(3) === 'a, true) + testPropagation('a === 2 && Literal(3) <=> 'a, true) + testPropagation('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(3) < 'a, true) + testPropagation('a <=> 2 && Literal(3) <= 'a, true) + testPropagation('a <=> 2 && Literal(3) === 'a, true) + testPropagation('a <=> 2 && Literal(3) <=> 'a, true) + testPropagation('a <=> 2 && Literal(3) >= 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(3) > 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= Literal(3))) + testPropagation('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < Literal(3))) + testPropagation('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= Literal(3))) + testPropagation('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < Literal(3))) + + testPropagation('a < 3 && Literal(2) < 'a, false, Seq('a < 3 && Literal(2) < 'a)) + testPropagation('a < 3 && Literal(2) <= 'a, false, Seq('a < 3 && Literal(2) <= 'a)) + testPropagation('a < 3 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a < 3 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a < 3 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a < 3 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a <= 3 && Literal(2) < 'a, false, Seq('a <= 3 && Literal(2) < 'a)) + testPropagation('a <= 3 && Literal(2) <= 'a, false, Seq('a <= 3 && Literal(2) <= 'a)) + testPropagation('a <= 3 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a <= 3 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a <= 3 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 3 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a === 3 && Literal(2) < 'a, false, Seq('a === 3)) + testPropagation('a === 3 && Literal(2) <= 'a, false, Seq('a === 3)) + testPropagation('a === 3 && Literal(2) === 'a, true) + testPropagation('a === 3 && Literal(2) <=> 'a, true) + testPropagation('a === 3 && Literal(2) >= 'a, true) + testPropagation('a === 3 && Literal(2) > 'a, true) + testPropagation('a <=> 3 && Literal(2) < 'a, false, Seq('a === 3)) + testPropagation('a <=> 3 && Literal(2) <= 'a, false, Seq('a === 3)) + testPropagation('a <=> 3 && Literal(2) === 'a, true) + testPropagation('a <=> 3 && Literal(2) <=> 'a, true) + testPropagation('a <=> 3 && Literal(2) >= 'a, true) + testPropagation('a <=> 3 && Literal(2) > 'a, true) + testPropagation('a >= 3 && Literal(2) < 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 3 && Literal(2) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 3 && Literal(2) === 'a, true) + testPropagation('a >= 3 && Literal(2) <=> 'a, true) + testPropagation('a >= 3 && Literal(2) >= 'a, true) + testPropagation('a >= 3 && Literal(2) > 'a, true) + testPropagation('a > 3 && Literal(2) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 3 && Literal(2) <= 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 3 && Literal(2) === 'a, true) + testPropagation('a > 3 && Literal(2) <=> 'a, true) + testPropagation('a > 3 && Literal(2) >= 'a, true) + testPropagation('a > 3 && Literal(2) > 'a, true) + + testPropagation('a < 'b && 'b < 'a, true) + testPropagation('a < 'b && 'b <= 'a, true) + testPropagation('a < 'b && 'b === 'a, true) + testPropagation('a < 'b && 'b <=> 'a, true) + testPropagation('a < 'b && 'b >= 'a, false, Seq('a < 'b)) + testPropagation('a < 'b && 'b > 'a, false, Seq('a < 'b)) + testPropagation('a <= 'b && 'b < 'a, true) + testPropagation('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) + testPropagation('a <= 'b && 'b > 'a, false, Seq('a < 'b)) + testPropagation('a === 'b && 'b < 'a, true) + testPropagation('a === 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b > 'a, true) + testPropagation('a <=> 'b && 'b < 'a, true) + testPropagation('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) + testPropagation('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b > 'a, true) + testPropagation('a >= 'b && 'b < 'a, false, Seq('b < 'a)) + testPropagation('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) + testPropagation('a >= 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b > 'a, true) + testPropagation('a > 'b && 'b < 'a, false, Seq('b < 'a)) + testPropagation('a > 'b && 'b <= 'a, false, Seq('b < 'a)) + testPropagation('a > 'b && 'b === 'a, true) + testPropagation('a > 'b && 'b <=> 'a, true) + testPropagation('a > 'b && 'b >= 'a, true) + testPropagation('a > 'b && 'b > 'a, true) + + testPropagation('a < abs('b) && abs('b) < 'a, true) + testPropagation('a < abs('b) && abs('b) <= 'a, true) + testPropagation('a < abs('b) && abs('b) === 'a, true) + testPropagation('a < abs('b) && abs('b) <=> 'a, true) + testPropagation('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) + testPropagation('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation('a <= abs('b) && abs('b) < 'a, true) + testPropagation('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testPropagation('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation('a === abs('b) && abs('b) < 'a, true) + testPropagation('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) > 'a, true) + testPropagation('a <=> abs('b) && abs('b) < 'a, true) + testPropagation('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) + testPropagation('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) > 'a, true) + testPropagation('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testPropagation('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) > 'a, true) + testPropagation('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) + testPropagation('a > abs('b) && abs('b) === 'a, true) + testPropagation('a > abs('b) && abs('b) <=> 'a, true) + testPropagation('a > abs('b) && abs('b) >= 'a, true) + testPropagation('a > abs('b) && abs('b) > 'a, true) + + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) === 'a, false, + Seq('x || 'y, abs('b) === 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a < abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) < 'a, false, + Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) > 'a, false, + Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) < 'a, false, + Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) > 'a, false, + Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || abs('b) < 'a || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) === 'a, false, + Seq('x || 'y, abs('b) === 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'y, 'a <= abs('b))) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index a40ba2dc38b7..e4671f0d1cce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -196,7 +196,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("constraints should be inferred from aliased literals") { val originalLeft = testRelation.subquery('left).as("left") - val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") val condition = Some("left.a".attr === "right.two".attr) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 5394732f41f2..ecb69c109951 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -95,13 +95,19 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case Filter(condition: Expression, child: LogicalPlan) => - Filter(splitConjunctivePredicates(condition).map(rewriteEqual).sortBy(_.hashCode()) - .reduce(And), child) + val newCondition = + splitConjunctivePredicates(condition) + .map(rewriteEqual) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(And) + Filter(newCondition, child) case sample: Sample => sample.copy(seed = 0L) case Join(left, right, joinType, condition, hint) if condition.isDefined => val newCondition = - splitConjunctivePredicates(condition.get).map(rewriteEqual).sortBy(_.hashCode()) + splitConjunctivePredicates(condition.get) + .map(rewriteEqual) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) .reduce(And) Join(left, right, joinType, Some(newCondition), hint) } @@ -113,12 +119,15 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => * 1. (a = b), (b = a); * 2. (a <=> b), (b <=> a). */ - private def rewriteEqual(condition: Expression): Expression = condition match { + private def rewriteEqual(condition: Expression): Expression = condition transform { case eq @ EqualTo(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualTo) + Seq(l, r) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(EqualTo) case eq @ EqualNullSafe(l: Expression, r: Expression) => - Seq(l, r).sortBy(_.hashCode()).reduce(EqualNullSafe) - case _ => condition // Don't reorder. + Seq(l, r) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(EqualNullSafe) } /** Fails the test if the two plans do not match */ From d1ae18cb2bb47ce909febfa82acaac7694827792 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 30 Apr 2019 16:31:03 +0200 Subject: [PATCH 02/12] normalize only once Change-Id: I8031a37f380d5fe5fff37488460b790dd65c0161 --- .../sql/catalyst/optimizer/expressions.scala | 2 +- .../plans/logical/QueryPlanConstraints.scala | 88 ++++++++++--------- 2 files changed, 46 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8108d6fc9994..8beb11aba61d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -66,7 +66,7 @@ object ConstantFolding extends Rule[LogicalPlan] { object ConstraintPropagation extends Rule[LogicalPlan] with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => - val (newCondition, _) = simplifyWithConstraints(f.condition) + val newCondition = simplifyWithConstraints(f.condition) if (newCondition fastEquals f.condition) { f } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 08687bd855c7..25f54ebc1753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -123,6 +123,14 @@ trait ConstraintHelper { case _ => Seq.empty[Attribute] } + def simplifyWithConstraints(expression: Expression): Expression = + simplify(normalize(expression))._1 + + private def normalize(expression: Expression) = expression transform { + case GreaterThan(x, y) => LessThan(y, x) + case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) + } + /** * Traverse a condition as a tree and simplify expressions with constraints. * - On matching [[And]], recursively traverse both children, simplify child expressions with @@ -136,53 +144,47 @@ trait ConstraintHelper { * 1. Expression: optionally changed condition after traversal * 2. Seq[Expression]: propagated constraints */ - def simplifyWithConstraints(expression: Expression): (Expression, Seq[Expression]) = - expression match { - case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe | _: GreaterThan | - _: GreaterThanOrEqual ) - if e.deterministic => (e, Seq(normalize(e))) - case a @ And(left, right) => - val (newLeft, leftConstraints) = simplifyWithConstraints(left) - val simplifiedRight = simplify(right, leftConstraints) - val (simplifiedNewRight, rightConstraints) = simplifyWithConstraints(simplifiedRight) - val simplifiedNewLeft = simplify(newLeft, rightConstraints) - val newAnd = if ((simplifiedNewLeft fastEquals left) && - (simplifiedNewRight fastEquals right)) { - a - } else { - And(simplifiedNewLeft, simplifiedNewRight) - } - (newAnd, leftConstraints ++ rightConstraints) - case o @ Or(left, right) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = simplifyWithConstraints(left) - val (newRight, _) = simplifyWithConstraints(right) - val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { - o - } else { - Or(newLeft, newRight) - } - - (newOr, Seq.empty) - case n @ Not(child) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = simplifyWithConstraints(child) - val newNot = if (newChild fastEquals child) { - n - } else { - Not(newChild) - } - (newNot, Seq.empty) - case _ => (expression, Seq.empty) - } + private def simplify(expression: Expression): (Expression, Seq[Expression]) = expression match { + case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe | _: GreaterThan | + _: GreaterThanOrEqual ) + if e.deterministic => (e, Seq(e)) + case a @ And(left, right) => + val (newLeft, leftConstraints) = simplify(left) + val simplifiedRight = simplify(right, leftConstraints) + val (simplifiedNewRight, rightConstraints) = simplify(simplifiedRight) + val simplifiedNewLeft = simplify(newLeft, rightConstraints) + val newAnd = if ((simplifiedNewLeft fastEquals left) && + (simplifiedNewRight fastEquals right)) { + a + } else { + And(simplifiedNewLeft, simplifiedNewRight) + } + (newAnd, leftConstraints ++ rightConstraints) + case o @ Or(left, right) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = simplify(left) + val (newRight, _) = simplify(right) + val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { + o + } else { + Or(newLeft, newRight) + } - private def normalize(expression: Expression) = expression transform { - case GreaterThan(x, y) => LessThan(y, x) - case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) + (newOr, Seq.empty) + case n @ Not(child) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = simplify(child) + val newNot = if (newChild fastEquals child) { + n + } else { + Not(newChild) + } + (newNot, Seq.empty) + case _ => (expression, Seq.empty) } private def simplify(expression: Expression, constraints: Seq[Expression]): Expression = - constraints.foldLeft(normalize(expression))((e, constraint) => simplify(e, constraint)) + constraints.foldLeft(expression)((e, constraint) => simplify(e, constraint)) private def planEqual(x: Expression, y: Expression) = !x.foldable && !y.foldable && x.canonicalized == y.canonicalized From 8aa8e36df1b82ad9201ff2f00758eb6e85f2855f Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 30 Apr 2019 17:35:55 +0200 Subject: [PATCH 03/12] fix null handling regarding EqualNullSafe Change-Id: I79ec5286b4762d98abaea638f8a9e5e9f7f753d0 --- .../catalyst/plans/logical/LogicalPlan.scala | 6 +- .../plans/logical/QueryPlanConstraints.scala | 59 +++++++++++-------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9bea4cfac3c3..9aa3177a9063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -173,15 +173,15 @@ abstract class UnaryNode extends LogicalPlan { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => - allConstraints += EqualTo(a.toAttribute, l) + allConstraints += + (if (l.nullable) EqualNullSafe(a.toAttribute, l) else EqualTo(a.toAttribute, l)) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { case expr: Expression if expr.semanticEquals(e) => a.toAttribute }) - allConstraints += - (if (e.foldable) EqualTo(e, a.toAttribute) else EqualNullSafe(e, a.toAttribute)) + allConstraints += EqualNullSafe(e, a.toAttribute) case _ => // Don't change. } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 25f54ebc1753..7969e3c47603 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -334,34 +334,41 @@ trait ConstraintHelper { case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) } } - case a EqualNullSafe b => expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(d, a) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(d, b) => - Literal.FalseLiteral + case a EqualNullSafe b => + if (b.foldable) { + expression transformUp { case c if planEqual(a, c) => b } + } else if (a.foldable) { + expression transformUp { case c if planEqual(b, c) => a } + } else { + expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => + Literal.FalseLiteral - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => - EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => + EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => + EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => - Literal.TrueLiteral - } + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => + Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => + Literal.TrueLiteral + } + } case _ => expression } } From 3737801a18b6e1df31e7452d153628f4bef0c837 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 3 May 2019 16:34:25 +0200 Subject: [PATCH 04/12] fix review findings --- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../sql/catalyst/optimizer/expressions.scala | 100 +++++- .../plans/logical/QueryPlanConstraints.scala | 258 ++++++-------- .../optimizer/ConstantPropagationSuite.scala | 284 +--------------- .../optimizer/FilterReductionSuite.scala | 321 ++++++++++++++++++ 5 files changed, 519 insertions(+), 447 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 741e6ee71b02..d6d79c2508ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -80,7 +80,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // Constant folding and strength reduction TransposeWindow, NullPropagation, - ConstraintPropagation, + ConstantPropagation, + FilterReduction, FoldablePropagation, OptimizeIn, ConstantFolding, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8beb11aba61d..ed3fb0a9e187 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -55,18 +55,103 @@ object ConstantFolding extends Rule[LogicalPlan] { } /** - * Substitutes expressions which can be statically narrowed by constrains. + * Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding + * value in conjunctive [[Expression Expressions]] * eg. * {{{ - * SELECT * FROM table WHERE i = 5 AND j = i + 3 => SELECT * FROM table WHERE i = 5 AND j = 8 - * SELECT * FROM table WHERE i <= 5 AND i = 5 => SELECT * FROM table WHERE i = 5 - * SELECT * FROM table WHERE i < j AND ... AND i > j => SELECT * FROM table WHERE false + * SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i = 5 AND j = 8 + * SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE i = 5 AND j <= 8 + * }}} + * + * Approach used: + * - Populate a mapping of expression => constant value by looking at all the deterministic equals + * predicates + * - Using this mapping, replace occurrence of the expressions with the corresponding constant + * values in the AND node. + */ +object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f: Filter => + val (newCondition, _) = traverse(f.condition) + if (newCondition fastEquals f.condition) { + f + } else { + f.copy(condition = newCondition) + } + } + + /** + * Traverse a condition as a tree and replace expressions with constant values. + * - On matching [[And]], recursively traverse each children and get propagated mappings. + * If the current node is not child of another [[And]], replace all occurrences of the + * expressions with the corresponding constant values. + * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping + * of expression => constant. + * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. + * - Otherwise, stop traversal and propagate empty mapping. + * @param expression expression to be traversed + * @return A tuple including: + * 1. Option[Expression]: optional changed condition after traversal + * 2. Seq[(Expression, Literal)]: propagated mapping of expression => constant + */ + private def traverse(expression: Expression): (Expression, Seq[(Expression, Literal)]) = + expression match { + case e @ EqualTo(left, right: Literal) if e.deterministic => (e, Seq((left, right))) + case e @ EqualTo(left: Literal, right) if e.deterministic => (e, Seq((right, left))) + case e @ EqualNullSafe(left, right: Literal) if e.deterministic => (e, Seq((left, right))) + case e @ EqualNullSafe(left: Literal, right) if e.deterministic => (e, Seq((right, left))) + case a @ And(left, right) => + val (newLeft, equalityPredicatesLeft) = traverse(left) + val replacedRight = replaceConstants(right, equalityPredicatesLeft) + val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight) + val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight) + val newAnd = if ((replacedNewLeft fastEquals left) && (replacedNewRight fastEquals right)) { + a + } else { + And(replacedNewLeft, replacedNewRight) + } + (newAnd, equalityPredicatesLeft ++ equalityPredicatesRight) + case o @ Or(left, right) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = traverse(left) + val (newRight, _) = traverse(right) + val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { + o + } else { + Or(newLeft, newRight) + } + + (newOr, Seq.empty) + case n @ Not(child) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = traverse(child) + val newNot = if (newChild fastEquals child) { + n + } else { + Not(newChild) + } + (newNot, Seq.empty) + case _ => (expression, Seq.empty) + } + + private def replaceConstants(expression: Expression, constants: Seq[(Expression, Literal)]) = + constants.foldLeft(expression)((e, constant) => e transformUp { + case e if e.canonicalized == constant._1.canonicalized => constant._2 + }) +} + +/** + * Substitutes expressions which can be statically reduced by constraints. + * eg. + * {{{ + * SELECT * FROM table WHERE i <= 5 AND i = 5 => ... WHERE i = 5 + * SELECT * FROM table WHERE i < j AND ... AND i > j => ... WHERE false * }}} */ -object ConstraintPropagation extends Rule[LogicalPlan] with ConstraintHelper { +object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => - val newCondition = simplifyWithConstraints(f.condition) + val newCondition = normalizeAndReduceWithConstraints(f.condition) if (newCondition fastEquals f.condition) { f } else { @@ -313,7 +398,8 @@ object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { case q: LogicalPlan => q transformExpressionsUp { // True with equality case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral - case a EqualNullSafe b if a.foldable || b.foldable => EqualTo(a, b) +// case a EqualNullSafe b if a.foldable && !a.nullable || b.foldable && !b.nullable => +// EqualTo(a, b) case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 7969e3c47603..2014ce1dd67a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -123,8 +123,8 @@ trait ConstraintHelper { case _ => Seq.empty[Attribute] } - def simplifyWithConstraints(expression: Expression): Expression = - simplify(normalize(expression))._1 + def normalizeAndReduceWithConstraints(expression: Expression): Expression = + reduceWithConstraints(normalize(expression))._1 private def normalize(expression: Expression) = expression transform { case GreaterThan(x, y) => LessThan(y, x) @@ -133,58 +133,59 @@ trait ConstraintHelper { /** * Traverse a condition as a tree and simplify expressions with constraints. + * - This functions assumes that the plan has been normalized using [[normalize()]] * - On matching [[And]], recursively traverse both children, simplify child expressions with * propagated constraints from sibling and propagate up union of constraints. * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]], - * [[GreaterThan]] or [[GreaterThanOrEqual]] propagate the constraint. + * propagate the constraint. * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints. * - Otherwise, stop traversal and propagate no constraints. * @param expression expression to be traversed * @return A tuple including: - * 1. Expression: optionally changed condition after traversal + * 1. Expression: optionally changed expression after traversal * 2. Seq[Expression]: propagated constraints */ - private def simplify(expression: Expression): (Expression, Seq[Expression]) = expression match { - case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe | _: GreaterThan | - _: GreaterThanOrEqual ) - if e.deterministic => (e, Seq(e)) - case a @ And(left, right) => - val (newLeft, leftConstraints) = simplify(left) - val simplifiedRight = simplify(right, leftConstraints) - val (simplifiedNewRight, rightConstraints) = simplify(simplifiedRight) - val simplifiedNewLeft = simplify(newLeft, rightConstraints) - val newAnd = if ((simplifiedNewLeft fastEquals left) && - (simplifiedNewRight fastEquals right)) { - a - } else { - And(simplifiedNewLeft, simplifiedNewRight) - } - (newAnd, leftConstraints ++ rightConstraints) - case o @ Or(left, right) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = simplify(left) - val (newRight, _) = simplify(right) - val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { - o - } else { - Or(newLeft, newRight) - } + private def reduceWithConstraints(expression: Expression): (Expression, Seq[Expression]) = + expression match { + case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe) + if e.deterministic => (e, Seq(e)) + case a @ And(left, right) => + val (newLeft, leftConstraints) = reduceWithConstraints(left) + val simplifiedRight = reduceWithConstraints(right, leftConstraints) + val (simplifiedNewRight, rightConstraints) = reduceWithConstraints(simplifiedRight) + val simplifiedNewLeft = reduceWithConstraints(newLeft, rightConstraints) + val newAnd = if ((simplifiedNewLeft fastEquals left) && + (simplifiedNewRight fastEquals right)) { + a + } else { + And(simplifiedNewLeft, simplifiedNewRight) + } + (newAnd, leftConstraints ++ rightConstraints) + case o @ Or(left, right) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = reduceWithConstraints(left) + val (newRight, _) = reduceWithConstraints(right) + val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { + o + } else { + Or(newLeft, newRight) + } - (newOr, Seq.empty) - case n @ Not(child) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = simplify(child) - val newNot = if (newChild fastEquals child) { - n - } else { - Not(newChild) - } - (newNot, Seq.empty) - case _ => (expression, Seq.empty) - } + (newOr, Seq.empty) + case n @ Not(child) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = reduceWithConstraints(child) + val newNot = if (newChild fastEquals child) { + n + } else { + Not(newChild) + } + (newNot, Seq.empty) + case _ => (expression, Seq.empty) + } - private def simplify(expression: Expression, constraints: Seq[Expression]): Expression = - constraints.foldLeft(expression)((e, constraint) => simplify(e, constraint)) + private def reduceWithConstraints(expression: Expression, constraints: Seq[Expression]) = + constraints.foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint)) private def planEqual(x: Expression, y: Expression) = !x.foldable && !y.foldable && x.canonicalized == y.canonicalized @@ -198,7 +199,7 @@ trait ConstraintHelper { private def valueLessThanOrEqual(x: Expression, y: Expression) = x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] - private def simplify(expression: Expression, constraint: Expression): Expression = + private def reduceWithConstraint(expression: Expression, constraint: Expression): Expression = constraint match { case a LessThan b => expression transformUp { case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => @@ -223,28 +224,19 @@ trait ConstraintHelper { if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => Literal.FalseLiteral - case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => - Literal.FalseLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => - Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => - Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => - Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => + if (planEqual(a, c)) Literal.FalseLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => + if (planEqual(a, d)) Literal.FalseLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => + if (planEqual(b, d)) Literal.FalseLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => + if (planEqual(b, c)) Literal.FalseLiteral else EqualTo(c, d) } case a LessThanOrEqual b => expression transformUp { case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => @@ -271,104 +263,52 @@ trait ConstraintHelper { case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => EqualTo(c, d) - case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => - Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => - Literal.FalseLiteral - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) } - case a EqualTo b => - if (b.foldable) { - expression transformUp { case c if planEqual(a, c) => b } - } else if (a.foldable) { - expression transformUp { case c if planEqual(b, c) => a } - } else { - expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(a, d) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(b, d) => - Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => - Literal.TrueLiteral - - case c EqualTo d if planEqual(b, d) && planEqual(a, c) => - Literal.TrueLiteral - case c EqualTo d if planEqual(b, c) && planEqual(a, d) => - Literal.TrueLiteral - case c EqualTo d if planEqual(a, d) && planEqual(b, c) => - Literal.TrueLiteral - case c EqualTo d if planEqual(a, c) && planEqual(b, d) => - Literal.TrueLiteral - - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => - Literal.TrueLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) - } - } - case a EqualNullSafe b => - if (b.foldable) { - expression transformUp { case c if planEqual(a, c) => b } - } else if (a.foldable) { - expression transformUp { case c if planEqual(b, c) => a } - } else { - expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(d, a) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(d, b) => - Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => - EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => - EqualTo(c, d) - - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => - Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => - Literal.TrueLiteral - } - } + case a EqualTo b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) => + if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => + if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => + if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => + if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) + } + case a EqualNullSafe b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } case _ => expression } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 2dce72505445..bbf390b7ddf6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -36,28 +35,12 @@ class ConstantPropagationSuite extends PlanTest { Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("ConstantPropagation", FixedPoint(10), - ConstraintPropagation, + ConstantPropagation, ConstantFolding, - BooleanSimplification, - SimplifyBinaryComparison, - PruneFilters) :: Nil + BooleanSimplification) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean, 'y.boolean) - - val data = { - val intElements = Seq(null, 1, 2, 3) - val booleanElements = Seq(null, true, false) - for { - a <- intElements - b <- intElements - c <- intElements - x <- booleanElements - y <- booleanElements - } yield (a, b, c, x, y) - } - - val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) private val columnA = 'a private val columnB = 'b @@ -171,267 +154,8 @@ class ConstantPropagationSuite extends PlanTest { .where( columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) - val correctAnswer = testRelation.analyze + val correctAnswer = testRelation.where(Literal.FalseLiteral).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - - private def testPropagation( - input: Expression, - expectEmptyRelation: Boolean, - expectedConstraints: Seq[Expression] = Seq.empty) = { - val originalQuery = testRelationWithData.where(input).analyze - val optimized = Optimize.execute(originalQuery) - val correctAnswer = if (expectEmptyRelation) { - testRelation - } else { - testRelationWithData.where(expectedConstraints.reduce(And)).analyze - } - comparePlans(optimized, correctAnswer) - } - - test("Enhanced constraint propagation") { - testPropagation('a < 2 && Literal(2) < 'a, true) - testPropagation('a < 2 && Literal(2) <= 'a, true) - testPropagation('a < 2 && Literal(2) === 'a, true) - testPropagation('a < 2 && Literal(2) <=> 'a, true) - testPropagation('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) - testPropagation('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a <= 2 && Literal(2) < 'a, true) - testPropagation('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) - testPropagation('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a <= 2 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a === 2 && Literal(2) < 'a, true) - testPropagation('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) > 'a, true) - testPropagation('a <=> 2 && Literal(2) < 'a, true) - testPropagation('a <=> 2 && Literal(2) <= 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(2) >= 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(2) > 'a, true) - testPropagation('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) - testPropagation('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(2) > 'a, true) - testPropagation('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a > 2 && Literal(2) === 'a, true) - testPropagation('a > 2 && Literal(2) <=> 'a, true) - testPropagation('a > 2 && Literal(2) >= 'a, true) - testPropagation('a > 2 && Literal(2) > 'a, true) - - testPropagation('a < 2 && Literal(3) < 'a, true) - testPropagation('a < 2 && Literal(3) <= 'a, true) - testPropagation('a < 2 && Literal(3) === 'a, true) - testPropagation('a < 2 && Literal(3) <=> 'a, true) - testPropagation('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) - testPropagation('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) - testPropagation('a <= 2 && Literal(3) < 'a, true) - testPropagation('a <= 2 && Literal(3) <= 'a, true) - testPropagation('a <= 2 && Literal(3) === 'a, true) - testPropagation('a <= 2 && Literal(3) <=> 'a, true) - testPropagation('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) - testPropagation('a === 2 && Literal(3) < 'a, true) - testPropagation('a === 2 && Literal(3) <= 'a, true) - testPropagation('a === 2 && Literal(3) === 'a, true) - testPropagation('a === 2 && Literal(3) <=> 'a, true) - testPropagation('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(3) < 'a, true) - testPropagation('a <=> 2 && Literal(3) <= 'a, true) - testPropagation('a <=> 2 && Literal(3) === 'a, true) - testPropagation('a <=> 2 && Literal(3) <=> 'a, true) - testPropagation('a <=> 2 && Literal(3) >= 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(3) > 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= Literal(3))) - testPropagation('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < Literal(3))) - testPropagation('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= Literal(3))) - testPropagation('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < Literal(3))) - - testPropagation('a < 3 && Literal(2) < 'a, false, Seq('a < 3 && Literal(2) < 'a)) - testPropagation('a < 3 && Literal(2) <= 'a, false, Seq('a < 3 && Literal(2) <= 'a)) - testPropagation('a < 3 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a < 3 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a < 3 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a < 3 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a <= 3 && Literal(2) < 'a, false, Seq('a <= 3 && Literal(2) < 'a)) - testPropagation('a <= 3 && Literal(2) <= 'a, false, Seq('a <= 3 && Literal(2) <= 'a)) - testPropagation('a <= 3 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a <= 3 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a <= 3 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 3 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a === 3 && Literal(2) < 'a, false, Seq('a === 3)) - testPropagation('a === 3 && Literal(2) <= 'a, false, Seq('a === 3)) - testPropagation('a === 3 && Literal(2) === 'a, true) - testPropagation('a === 3 && Literal(2) <=> 'a, true) - testPropagation('a === 3 && Literal(2) >= 'a, true) - testPropagation('a === 3 && Literal(2) > 'a, true) - testPropagation('a <=> 3 && Literal(2) < 'a, false, Seq('a === 3)) - testPropagation('a <=> 3 && Literal(2) <= 'a, false, Seq('a === 3)) - testPropagation('a <=> 3 && Literal(2) === 'a, true) - testPropagation('a <=> 3 && Literal(2) <=> 'a, true) - testPropagation('a <=> 3 && Literal(2) >= 'a, true) - testPropagation('a <=> 3 && Literal(2) > 'a, true) - testPropagation('a >= 3 && Literal(2) < 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 3 && Literal(2) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 3 && Literal(2) === 'a, true) - testPropagation('a >= 3 && Literal(2) <=> 'a, true) - testPropagation('a >= 3 && Literal(2) >= 'a, true) - testPropagation('a >= 3 && Literal(2) > 'a, true) - testPropagation('a > 3 && Literal(2) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 3 && Literal(2) <= 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 3 && Literal(2) === 'a, true) - testPropagation('a > 3 && Literal(2) <=> 'a, true) - testPropagation('a > 3 && Literal(2) >= 'a, true) - testPropagation('a > 3 && Literal(2) > 'a, true) - - testPropagation('a < 'b && 'b < 'a, true) - testPropagation('a < 'b && 'b <= 'a, true) - testPropagation('a < 'b && 'b === 'a, true) - testPropagation('a < 'b && 'b <=> 'a, true) - testPropagation('a < 'b && 'b >= 'a, false, Seq('a < 'b)) - testPropagation('a < 'b && 'b > 'a, false, Seq('a < 'b)) - testPropagation('a <= 'b && 'b < 'a, true) - testPropagation('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) - testPropagation('a <= 'b && 'b > 'a, false, Seq('a < 'b)) - testPropagation('a === 'b && 'b < 'a, true) - testPropagation('a === 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b > 'a, true) - testPropagation('a <=> 'b && 'b < 'a, true) - testPropagation('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) - testPropagation('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b > 'a, true) - testPropagation('a >= 'b && 'b < 'a, false, Seq('b < 'a)) - testPropagation('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) - testPropagation('a >= 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b > 'a, true) - testPropagation('a > 'b && 'b < 'a, false, Seq('b < 'a)) - testPropagation('a > 'b && 'b <= 'a, false, Seq('b < 'a)) - testPropagation('a > 'b && 'b === 'a, true) - testPropagation('a > 'b && 'b <=> 'a, true) - testPropagation('a > 'b && 'b >= 'a, true) - testPropagation('a > 'b && 'b > 'a, true) - - testPropagation('a < abs('b) && abs('b) < 'a, true) - testPropagation('a < abs('b) && abs('b) <= 'a, true) - testPropagation('a < abs('b) && abs('b) === 'a, true) - testPropagation('a < abs('b) && abs('b) <=> 'a, true) - testPropagation('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) - testPropagation('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation('a <= abs('b) && abs('b) < 'a, true) - testPropagation('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) - testPropagation('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation('a === abs('b) && abs('b) < 'a, true) - testPropagation('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) > 'a, true) - testPropagation('a <=> abs('b) && abs('b) < 'a, true) - testPropagation('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) - testPropagation('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) > 'a, true) - testPropagation('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) - testPropagation('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) > 'a, true) - testPropagation('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) - testPropagation('a > abs('b) && abs('b) === 'a, true) - testPropagation('a > abs('b) && abs('b) <=> 'a, true) - testPropagation('a > abs('b) && abs('b) >= 'a, true) - testPropagation('a > abs('b) && abs('b) > 'a, true) - - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) === 'a, false, - Seq('x || 'y, abs('b) === 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a < abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) >= 'a, false, Seq('a <= abs('b))) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) < 'a, false, - Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) > 'a, false, - Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) < 'a, false, - Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) > 'a, false, - Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || abs('b) < 'a || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) === 'a, false, - Seq('x || 'y, abs('b) === 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'y, 'a <= abs('b))) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala new file mode 100644 index 000000000000..ee823697267a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for constant propagation in expressions. + */ +class FilterReductionSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("FilterReduction", FixedPoint(10), + ConstantPropagation, + FilterReduction, + ConstantFolding, + BooleanSimplification, + SimplifyBinaryComparison, + PruneFilters) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean, 'y.boolean) + + val data = { + val intElements = Seq(null, 1, 2, 3) + val booleanElements = Seq(null, true, false) + for { + a <- intElements + b <- intElements + c <- intElements + x <- booleanElements + y <- booleanElements + } yield (a, b, c, x, y) + } + + val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) + + private def testPropagation( + input: Expression, + expectEmptyRelation: Boolean, + expectedConstraints: Seq[Expression] = Seq.empty) = { + val originalQuery = testRelationWithData.where(input).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = if (expectEmptyRelation) { + testRelation + } else { + testRelationWithData.where(expectedConstraints.reduce(And)).analyze + } + comparePlans(optimized, correctAnswer) + } + + test("Filter reduction") { + testPropagation('a < 2 && Literal(2) < 'a, true) + testPropagation('a < 2 && Literal(2) <= 'a, true) + testPropagation('a < 2 && Literal(2) === 'a, true) + testPropagation('a < 2 && Literal(2) <=> 'a, true) + testPropagation('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) + testPropagation('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a <= 2 && Literal(2) < 'a, true) + testPropagation('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testPropagation('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a <= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testPropagation('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a === 2 && Literal(2) < 'a, true) + testPropagation('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(2) > 'a, true) + testPropagation('a <=> 2 && Literal(2) < 'a, true) + testPropagation('a <=> 2 && Literal(2) <= 'a, false, Seq('a <=> 2)) + testPropagation('a <=> 2 && Literal(2) === 'a, false, Seq('a <=> 2)) + testPropagation('a <=> 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testPropagation('a <=> 2 && Literal(2) >= 'a, false, Seq('a <=> 2)) + testPropagation('a <=> 2 && Literal(2) > 'a, true) + testPropagation('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testPropagation('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testPropagation('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testPropagation('a >= 2 && Literal(2) > 'a, true) + testPropagation('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) + testPropagation('a > 2 && Literal(2) === 'a, true) + testPropagation('a > 2 && Literal(2) <=> 'a, true) + testPropagation('a > 2 && Literal(2) >= 'a, true) + testPropagation('a > 2 && Literal(2) > 'a, true) + + testPropagation('a < 2 && Literal(3) < 'a, true) + testPropagation('a < 2 && Literal(3) <= 'a, true) + testPropagation('a < 2 && Literal(3) === 'a, true) + testPropagation('a < 2 && Literal(3) <=> 'a, true) + testPropagation('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) + testPropagation('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) + testPropagation('a <= 2 && Literal(3) < 'a, true) + testPropagation('a <= 2 && Literal(3) <= 'a, true) + testPropagation('a <= 2 && Literal(3) === 'a, true) + testPropagation('a <= 2 && Literal(3) <=> 'a, true) + testPropagation('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) + testPropagation('a === 2 && Literal(3) < 'a, true) + testPropagation('a === 2 && Literal(3) <= 'a, true) + testPropagation('a === 2 && Literal(3) === 'a, true) + testPropagation('a === 2 && Literal(3) <=> 'a, true) + testPropagation('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) + testPropagation('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) + testPropagation('a <=> 2 && Literal(3) < 'a, true) + testPropagation('a <=> 2 && Literal(3) <= 'a, true) + testPropagation('a <=> 2 && Literal(3) === 'a, true) + testPropagation('a <=> 2 && Literal(3) <=> 'a, true) + testPropagation('a <=> 2 && Literal(3) >= 'a, false, Seq('a <=> 2)) + testPropagation('a <=> 2 && Literal(3) > 'a, false, Seq('a <=> 2)) + testPropagation('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testPropagation('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= 3)) + testPropagation('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < 3)) + testPropagation('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testPropagation('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= 3)) + testPropagation('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < 3)) + + testPropagation('a < 3 && Literal(2) < 'a, false, Seq('a < 3 && Literal(2) < 'a)) + testPropagation('a < 3 && Literal(2) <= 'a, false, Seq('a < 3 && Literal(2) <= 'a)) + testPropagation('a < 3 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a < 3 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testPropagation('a < 3 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a < 3 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a <= 3 && Literal(2) < 'a, false, Seq('a <= 3 && Literal(2) < 'a)) + testPropagation('a <= 3 && Literal(2) <= 'a, false, Seq('a <= 3 && Literal(2) <= 'a)) + testPropagation('a <= 3 && Literal(2) === 'a, false, Seq('a === 2)) + testPropagation('a <= 3 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testPropagation('a <= 3 && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation('a <= 3 && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation('a === 3 && Literal(2) < 'a, false, Seq('a === 3)) + testPropagation('a === 3 && Literal(2) <= 'a, false, Seq('a === 3)) + testPropagation('a === 3 && Literal(2) === 'a, true) + testPropagation('a === 3 && Literal(2) <=> 'a, true) + testPropagation('a === 3 && Literal(2) >= 'a, true) + testPropagation('a === 3 && Literal(2) > 'a, true) + testPropagation('a <=> 3 && Literal(2) < 'a, false, Seq('a <=> 3)) + testPropagation('a <=> 3 && Literal(2) <= 'a, false, Seq('a <=> 3)) + testPropagation('a <=> 3 && Literal(2) === 'a, true) + testPropagation('a <=> 3 && Literal(2) <=> 'a, true) + testPropagation('a <=> 3 && Literal(2) >= 'a, true) + testPropagation('a <=> 3 && Literal(2) > 'a, true) + testPropagation('a >= 3 && Literal(2) < 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 3 && Literal(2) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation('a >= 3 && Literal(2) === 'a, true) + testPropagation('a >= 3 && Literal(2) <=> 'a, true) + testPropagation('a >= 3 && Literal(2) >= 'a, true) + testPropagation('a >= 3 && Literal(2) > 'a, true) + testPropagation('a > 3 && Literal(2) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 3 && Literal(2) <= 'a, false, Seq(Literal(3) < 'a)) + testPropagation('a > 3 && Literal(2) === 'a, true) + testPropagation('a > 3 && Literal(2) <=> 'a, true) + testPropagation('a > 3 && Literal(2) >= 'a, true) + testPropagation('a > 3 && Literal(2) > 'a, true) + + testPropagation('a < 'b && 'b < 'a, true) + testPropagation('a < 'b && 'b <= 'a, true) + testPropagation('a < 'b && 'b === 'a, true) + testPropagation('a < 'b && 'b <=> 'a, true) + testPropagation('a < 'b && 'b >= 'a, false, Seq('a < 'b)) + testPropagation('a < 'b && 'b > 'a, false, Seq('a < 'b)) + testPropagation('a <= 'b && 'b < 'a, true) + testPropagation('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) + testPropagation('a <= 'b && 'b > 'a, false, Seq('a < 'b)) + testPropagation('a === 'b && 'b < 'a, true) + testPropagation('a === 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a === 'b && 'b > 'a, true) + testPropagation('a <=> 'b && 'b < 'a, true) + testPropagation('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) + testPropagation('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a <=> 'b && 'b > 'a, true) + testPropagation('a >= 'b && 'b < 'a, false, Seq('b < 'a)) + testPropagation('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) + testPropagation('a >= 'b && 'b === 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) + testPropagation('a >= 'b && 'b > 'a, true) + testPropagation('a > 'b && 'b < 'a, false, Seq('b < 'a)) + testPropagation('a > 'b && 'b <= 'a, false, Seq('b < 'a)) + testPropagation('a > 'b && 'b === 'a, true) + testPropagation('a > 'b && 'b <=> 'a, true) + testPropagation('a > 'b && 'b >= 'a, true) + testPropagation('a > 'b && 'b > 'a, true) + + testPropagation('a < abs('b) && abs('b) < 'a, true) + testPropagation('a < abs('b) && abs('b) <= 'a, true) + testPropagation('a < abs('b) && abs('b) === 'a, true) + testPropagation('a < abs('b) && abs('b) <=> 'a, true) + testPropagation('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) + testPropagation('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation('a <= abs('b) && abs('b) < 'a, true) + testPropagation('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testPropagation('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation('a === abs('b) && abs('b) < 'a, true) + testPropagation('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a === abs('b) && abs('b) > 'a, true) + testPropagation('a <=> abs('b) && abs('b) < 'a, true) + testPropagation('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) + testPropagation('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a <=> abs('b) && abs('b) > 'a, true) + testPropagation('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testPropagation('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testPropagation('a >= abs('b) && abs('b) > 'a, true) + testPropagation('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) + testPropagation('a > abs('b) && abs('b) === 'a, true) + testPropagation('a > abs('b) && abs('b) <=> 'a, true) + testPropagation('a > abs('b) && abs('b) >= 'a, true) + testPropagation('a > abs('b) && abs('b) > 'a, true) + + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) === 'a, false, + Seq('x || 'y, abs('b) === 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a < abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a < abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) < 'a, false, + Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a === abs('b) || 'y) && abs('b) > 'a, false, + Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) < 'a, false, + Seq('x || 'y, abs('b) < 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) > 'a, false, + Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) + testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <= 'a, false, + Seq('x || abs('b) < 'a || 'y, abs('b) <= 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) === 'a, false, + Seq('x || 'y, abs('b) === 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <=> 'a, false, + Seq('x || 'y, abs('b) <=> 'a)) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) >= 'a, false, + Seq('x || 'y, 'a <= abs('b))) + testPropagation(('x || 'a > abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) + } +} From 9ee3be3642dc285fbb618f04343d52cc1328eeee Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 6 May 2019 14:29:11 +0200 Subject: [PATCH 05/12] revert unnecessary changes, minor fix Change-Id: Ia68242ba5067e84143e35a1126a49017c5a203f6 --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 4 +--- .../apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 3 +-- .../catalyst/optimizer/InferFiltersFromConstraintsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index ed3fb0a9e187..146b0f859fdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -60,7 +60,7 @@ object ConstantFolding extends Rule[LogicalPlan] { * eg. * {{{ * SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i = 5 AND j = 8 - * SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE i = 5 AND j <= 8 + * SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE abs(i) = 5 AND j <= 8 * }}} * * Approach used: @@ -398,8 +398,6 @@ object SimplifyBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { case q: LogicalPlan => q transformExpressionsUp { // True with equality case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral -// case a EqualNullSafe b if a.foldable && !a.nullable || b.foldable && !b.nullable => -// EqualTo(a, b) case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 9aa3177a9063..db272b3d415e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -173,8 +173,7 @@ abstract class UnaryNode extends LogicalPlan { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { case a @ Alias(l: Literal, _) => - allConstraints += - (if (l.nullable) EqualNullSafe(a.toAttribute, l) else EqualTo(a.toAttribute, l)) + allConstraints += EqualNullSafe(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index e4671f0d1cce..a40ba2dc38b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -196,7 +196,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { test("constraints should be inferred from aliased literals") { val originalLeft = testRelation.subquery('left).as("left") - val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left") val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") val condition = Some("left.a".attr === "right.two".attr) From 169563ba80f277c40ae52b34ae582bb63047faac Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 7 May 2019 13:14:07 +0200 Subject: [PATCH 06/12] revert accidentally removed logic, add more tests Change-Id: Ice97dbafb67bf6cf84107b9064f42be991cbd9f8 --- .../plans/logical/QueryPlanConstraints.scala | 47 +++- .../optimizer/FilterReductionSuite.scala | 224 ++++++++++-------- 2 files changed, 163 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index 2014ce1dd67a..e13899a4c98c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -224,19 +224,32 @@ trait ConstraintHelper { if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => Literal.FalseLiteral - case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral - case c EqualNullSafe d if planEqual(b, d) => - if (planEqual(a, c)) Literal.FalseLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => - if (planEqual(a, d)) Literal.FalseLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => - if (planEqual(b, d)) Literal.FalseLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => - if (planEqual(b, c)) Literal.FalseLiteral else EqualTo(c, d) + case c EqualNullSafe d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) } case a LessThanOrEqual b => expression transformUp { case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => @@ -263,6 +276,16 @@ trait ConstraintHelper { case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => EqualTo(c, d) + case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala index ee823697267a..3e711ca39ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala @@ -44,7 +44,7 @@ class FilterReductionSuite extends PlanTest { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean, 'y.boolean) + val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean) val data = { val intElements = Seq(null, 1, 2, 3) @@ -54,8 +54,7 @@ class FilterReductionSuite extends PlanTest { b <- intElements c <- intElements x <- booleanElements - y <- booleanElements - } yield (a, b, c, x, y) + } yield (a, b, c, x) } val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) @@ -112,6 +111,48 @@ class FilterReductionSuite extends PlanTest { testPropagation('a > 2 && Literal(2) >= 'a, true) testPropagation('a > 2 && Literal(2) > 'a, true) + testPropagation(('x || 'a < 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testPropagation(('x || 'a < 2) && Literal(2) <= 'a, false, Seq('x, Literal(2) <= 'a)) + testPropagation(('x || 'a < 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testPropagation(('x || 'a < 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testPropagation(('x || 'a < 2) && Literal(2) >= 'a, false, Seq('x || 'a < 2, 'a <= 2)) + testPropagation(('x || 'a < 2) && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation(('x || 'a <= 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testPropagation(('x || 'a <= 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testPropagation(('x || 'a <= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testPropagation(('x || 'a <= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testPropagation(('x || 'a <= 2) && Literal(2) >= 'a, false, Seq('a <= 2)) + testPropagation(('x || 'a <= 2) && Literal(2) > 'a, false, Seq('a < 2)) + testPropagation(('x || 'a === 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testPropagation(('x || 'a === 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testPropagation(('x || 'a === 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testPropagation(('x || 'a === 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testPropagation(('x || 'a === 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testPropagation(('x || 'a === 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testPropagation(('x || 'a <=> 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testPropagation(('x || 'a <=> 2) && Literal(2) <= 'a, false, + Seq('x || 'a === 2, Literal(2) <= 'a)) + testPropagation(('x || 'a <=> 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testPropagation(('x || 'a <=> 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testPropagation(('x || 'a <=> 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testPropagation(('x || 'a <=> 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testPropagation(('x || 'a >= 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation(('x || 'a >= 2) && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testPropagation(('x || 'a >= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testPropagation(('x || 'a >= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testPropagation(('x || 'a >= 2) && Literal(2) >= 'a, false, + Seq('x || 'a === 2, 'a <= Literal(2))) + testPropagation(('x || 'a >= 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + testPropagation(('x || 'a > 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testPropagation(('x || 'a > 2) && Literal(2) <= 'a, false, + Seq('x || Literal(2) < 'a, Literal(2) <= 'a)) + testPropagation(('x || 'a > 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testPropagation(('x || 'a > 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testPropagation(('x || 'a > 2) && Literal(2) >= 'a, false, Seq('x, 'a <= Literal(2))) + testPropagation(('x || 'a > 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + testPropagation('a < 2 && Literal(3) < 'a, true) testPropagation('a < 2 && Literal(3) <= 'a, true) testPropagation('a < 2 && Literal(3) === 'a, true) @@ -149,42 +190,43 @@ class FilterReductionSuite extends PlanTest { testPropagation('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= 3)) testPropagation('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < 3)) - testPropagation('a < 3 && Literal(2) < 'a, false, Seq('a < 3 && Literal(2) < 'a)) - testPropagation('a < 3 && Literal(2) <= 'a, false, Seq('a < 3 && Literal(2) <= 'a)) - testPropagation('a < 3 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a < 3 && Literal(2) <=> 'a, false, Seq('a <=> 2)) - testPropagation('a < 3 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a < 3 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a <= 3 && Literal(2) < 'a, false, Seq('a <= 3 && Literal(2) < 'a)) - testPropagation('a <= 3 && Literal(2) <= 'a, false, Seq('a <= 3 && Literal(2) <= 'a)) - testPropagation('a <= 3 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a <= 3 && Literal(2) <=> 'a, false, Seq('a <=> 2)) - testPropagation('a <= 3 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 3 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a === 3 && Literal(2) < 'a, false, Seq('a === 3)) - testPropagation('a === 3 && Literal(2) <= 'a, false, Seq('a === 3)) - testPropagation('a === 3 && Literal(2) === 'a, true) - testPropagation('a === 3 && Literal(2) <=> 'a, true) - testPropagation('a === 3 && Literal(2) >= 'a, true) - testPropagation('a === 3 && Literal(2) > 'a, true) - testPropagation('a <=> 3 && Literal(2) < 'a, false, Seq('a <=> 3)) - testPropagation('a <=> 3 && Literal(2) <= 'a, false, Seq('a <=> 3)) - testPropagation('a <=> 3 && Literal(2) === 'a, true) - testPropagation('a <=> 3 && Literal(2) <=> 'a, true) - testPropagation('a <=> 3 && Literal(2) >= 'a, true) - testPropagation('a <=> 3 && Literal(2) > 'a, true) - testPropagation('a >= 3 && Literal(2) < 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 3 && Literal(2) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 3 && Literal(2) === 'a, true) - testPropagation('a >= 3 && Literal(2) <=> 'a, true) - testPropagation('a >= 3 && Literal(2) >= 'a, true) - testPropagation('a >= 3 && Literal(2) > 'a, true) - testPropagation('a > 3 && Literal(2) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 3 && Literal(2) <= 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 3 && Literal(2) === 'a, true) - testPropagation('a > 3 && Literal(2) <=> 'a, true) - testPropagation('a > 3 && Literal(2) >= 'a, true) - testPropagation('a > 3 && Literal(2) > 'a, true) + testPropagation(('x || 'a < 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testPropagation(('x || 'a < 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testPropagation(('x || 'a < 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testPropagation(('x || 'a < 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testPropagation(('x || 'a < 2) && Literal(3) >= 'a, false, Seq('x || 'a < 2, 'a <= 3)) + testPropagation(('x || 'a < 2) && Literal(3) > 'a, false, Seq('x || 'a < 2, 'a < 3)) + testPropagation(('x || 'a <= 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testPropagation(('x || 'a <= 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testPropagation(('x || 'a <= 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testPropagation(('x || 'a <= 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testPropagation(('x || 'a <= 2) && Literal(3) >= 'a, false, Seq('x || 'a <= 2, 'a <= 3)) + testPropagation(('x || 'a <= 2) && Literal(3) > 'a, false, Seq('x || 'a <= 2, 'a < 3)) + testPropagation(('x || 'a === 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testPropagation(('x || 'a === 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testPropagation(('x || 'a === 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testPropagation(('x || 'a === 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testPropagation(('x || 'a === 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testPropagation(('x || 'a === 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testPropagation(('x || 'a <=> 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testPropagation(('x || 'a <=> 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testPropagation(('x || 'a <=> 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testPropagation(('x || 'a <=> 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testPropagation(('x || 'a <=> 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testPropagation(('x || 'a <=> 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testPropagation(('x || 'a >= 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation(('x || 'a >= 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation(('x || 'a >= 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation(('x || 'a >= 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testPropagation(('x || 'a >= 2) && Literal(3) >= 'a, false, + Seq('x || Literal(2) <= 'a, 'a <= 3)) + testPropagation(('x || 'a >= 2) && Literal(3) > 'a, false, Seq('x || Literal(2) <= 'a, 'a < 3)) + testPropagation(('x || 'a > 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testPropagation(('x || 'a > 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testPropagation(('x || 'a > 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testPropagation(('x || 'a > 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testPropagation(('x || 'a > 2) && Literal(3) >= 'a, false, Seq('x || Literal(2) < 'a, 'a <= 3)) + testPropagation(('x || 'a > 2) && Literal(3) > 'a, false, Seq('x || Literal(2) < 'a, 'a < 3)) testPropagation('a < 'b && 'b < 'a, true) testPropagation('a < 'b && 'b <= 'a, true) @@ -260,62 +302,52 @@ class FilterReductionSuite extends PlanTest { testPropagation('a > abs('b) && abs('b) >= 'a, true) testPropagation('a > abs('b) && abs('b) > 'a, true) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) === 'a, false, - Seq('x || 'y, abs('b) === 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a < abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a < abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) < 'a, false, Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) >= 'a, false, Seq('a <= abs('b))) - testPropagation(('x || 'a <= abs('b) || 'y) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) < 'a, false, - Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a === abs('b) || 'y) && abs('b) > 'a, false, - Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) < 'a, false, - Seq('x || 'y, abs('b) < 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a <=> abs('b) || 'y) && abs('b) > 'a, false, - Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'a === abs('b) || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'a === abs('b) || 'y, 'a <= abs('b))) - testPropagation(('x || 'a >= abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <= 'a, false, - Seq('x || abs('b) < 'a || 'y, abs('b) <= 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) === 'a, false, - Seq('x || 'y, abs('b) === 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) <=> 'a, false, - Seq('x || 'y, abs('b) <=> 'a)) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) >= 'a, false, - Seq('x || 'y, 'a <= abs('b))) - testPropagation(('x || 'a > abs('b) || 'y) && abs('b) > 'a, false, Seq('x || 'y, 'a < abs('b))) + testPropagation(('x || 'a < abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testPropagation(('x || 'a < abs('b)) && abs('b) <= 'a, false, Seq('x, abs('b) <= 'a)) + testPropagation(('x || 'a < abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testPropagation(('x || 'a < abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testPropagation(('x || 'a < abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a < abs('b), 'a <= abs('b))) + testPropagation(('x || 'a < abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a <= abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testPropagation(('x || 'a <= abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testPropagation(('x || 'a <= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <= abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testPropagation(('x || 'a <= abs('b)) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testPropagation(('x || 'a <= abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testPropagation(('x || 'a === abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testPropagation(('x || 'a === abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testPropagation(('x || 'a === abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a === abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testPropagation(('x || 'a === abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testPropagation(('x || 'a === abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testPropagation(('x || 'a <=> abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testPropagation(('x || 'a <=> abs('b)) && abs('b) <= 'a, false, + Seq('x || 'a === abs('b), abs('b) <= 'a)) + testPropagation(('x || 'a <=> abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a <=> abs('b)) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) + testPropagation(('x || 'a <=> abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testPropagation(('x || 'a <=> abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testPropagation(('x || 'a >= abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a >= abs('b)) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testPropagation(('x || 'a >= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testPropagation(('x || 'a >= abs('b)) && abs('b) <=> 'a, false, + Seq('x || 'a === abs('b), abs('b) <=> 'a)) + testPropagation(('x || 'a >= abs('b)) && abs('b) >= 'a, false, + Seq('x || 'a === abs('b), 'a <= abs('b))) + testPropagation(('x || 'a >= abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testPropagation(('x || 'a > abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testPropagation(('x || 'a > abs('b)) && abs('b) <= 'a, false, + Seq('x || abs('b) < 'a, abs('b) <= 'a)) + testPropagation(('x || 'a > abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testPropagation(('x || 'a > abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testPropagation(('x || 'a > abs('b)) && abs('b) >= 'a, false, Seq('x, 'a <= abs('b))) + testPropagation(('x || 'a > abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) } } From 3c982b6e64d7ec8a93409f464b4ab179ee96ac43 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 7 May 2019 14:05:42 +0200 Subject: [PATCH 07/12] use map of equality predicates in ConstantPropagation --- .../sql/catalyst/optimizer/expressions.scala | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 146b0f859fdc..8eaed68ca7c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import scala.collection.mutable.{ArrayBuffer, Stack} +import scala.collection.mutable.{ArrayBuffer, Map, Stack} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -94,12 +94,16 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { * 1. Option[Expression]: optional changed condition after traversal * 2. Seq[(Expression, Literal)]: propagated mapping of expression => constant */ - private def traverse(expression: Expression): (Expression, Seq[(Expression, Literal)]) = + private def traverse(expression: Expression): (Expression, Map[Expression, Literal]) = expression match { - case e @ EqualTo(left, right: Literal) if e.deterministic => (e, Seq((left, right))) - case e @ EqualTo(left: Literal, right) if e.deterministic => (e, Seq((right, left))) - case e @ EqualNullSafe(left, right: Literal) if e.deterministic => (e, Seq((left, right))) - case e @ EqualNullSafe(left: Literal, right) if e.deterministic => (e, Seq((right, left))) + case e @ EqualTo(left, right: Literal) if e.deterministic => + (e, Map(left.canonicalized -> right)) + case e @ EqualTo(left: Literal, right) if e.deterministic => + (e, Map(right.canonicalized -> left)) + case e @ EqualNullSafe(left, right: Literal) if e.deterministic => + (e, Map(left.canonicalized -> right)) + case e @ EqualNullSafe(left: Literal, right) if e.deterministic => + (e, Map(right.canonicalized -> left)) case a @ And(left, right) => val (newLeft, equalityPredicatesLeft) = traverse(left) val replacedRight = replaceConstants(right, equalityPredicatesLeft) @@ -110,7 +114,7 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { } else { And(replacedNewLeft, replacedNewRight) } - (newAnd, equalityPredicatesLeft ++ equalityPredicatesRight) + (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight) case o @ Or(left, right) => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newLeft, _) = traverse(left) @@ -121,7 +125,7 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { Or(newLeft, newRight) } - (newOr, Seq.empty) + (newOr, Map.empty) case n @ Not(child) => // Ignore the EqualityPredicates from children since they are only propagated through And. val (newChild, _) = traverse(child) @@ -130,14 +134,16 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { } else { Not(newChild) } - (newNot, Seq.empty) - case _ => (expression, Seq.empty) + (newNot, Map.empty) + case _ => (expression, Map.empty) } - private def replaceConstants(expression: Expression, constants: Seq[(Expression, Literal)]) = - constants.foldLeft(expression)((e, constant) => e transformUp { - case e if e.canonicalized == constant._1.canonicalized => constant._2 - }) + private def replaceConstants( + expression: Expression, + constants: Map[Expression, Literal]) = + expression transformUp { + case e if constants.contains(e.canonicalized) => constants(e.canonicalized) + } } /** From b0786e1da7a1a1fb095c6b1531c5a51bad7f9d18 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 8 May 2019 08:56:47 +0200 Subject: [PATCH 08/12] minor fix --- .../spark/sql/catalyst/optimizer/expressions.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8eaed68ca7c8..e7674e06a409 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -92,17 +92,17 @@ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { * @param expression expression to be traversed * @return A tuple including: * 1. Option[Expression]: optional changed condition after traversal - * 2. Seq[(Expression, Literal)]: propagated mapping of expression => constant + * 2. Map[Expression, Literal]: propagated mapping of expression => constant */ private def traverse(expression: Expression): (Expression, Map[Expression, Literal]) = expression match { - case e @ EqualTo(left, right: Literal) if e.deterministic => + case e @ EqualTo(left, right: Literal) if !left.foldable && left.deterministic => (e, Map(left.canonicalized -> right)) - case e @ EqualTo(left: Literal, right) if e.deterministic => + case e @ EqualTo(left: Literal, right) if !right.foldable && right.deterministic => (e, Map(right.canonicalized -> left)) - case e @ EqualNullSafe(left, right: Literal) if e.deterministic => + case e @ EqualNullSafe(left, right: Literal) if !left.foldable && left.deterministic => (e, Map(left.canonicalized -> right)) - case e @ EqualNullSafe(left: Literal, right) if e.deterministic => + case e @ EqualNullSafe(left: Literal, right) if !right.foldable && right.deterministic => (e, Map(right.canonicalized -> left)) case a @ And(left, right) => val (newLeft, equalityPredicatesLeft) = traverse(left) From 693cdbb3d608a68df5c83c3a97ab64cf5f4cc35d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 8 May 2019 21:33:37 +0200 Subject: [PATCH 09/12] remove constant propagation changes Change-Id: I8ba31ae5043184d53662a4cedfe011a36d63759f --- .../sql/catalyst/optimizer/expressions.scala | 333 +++++++++++++++--- .../plans/logical/QueryPlanConstraints.scala | 212 ----------- .../optimizer/ConstantPropagationSuite.scala | 7 +- 3 files changed, 281 insertions(+), 271 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e7674e06a409..7329727c629e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet -import scala.collection.mutable.{ArrayBuffer, Map, Stack} +import scala.collection.mutable.{ArrayBuffer, Stack} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -55,95 +55,100 @@ object ConstantFolding extends Rule[LogicalPlan] { } /** - * Substitutes [[Expression Expressions]] which can be statically evaluated with their corresponding + * Substitutes [[Attribute Attributes]] which can be statically evaluated with their corresponding * value in conjunctive [[Expression Expressions]] * eg. * {{{ - * SELECT * FROM table WHERE i = 5 AND j = i + 3 => ... WHERE i = 5 AND j = 8 - * SELECT * FROM table WHERE abs(i) = 5 AND j <= abs(i) + 3 => ... WHERE abs(i) = 5 AND j <= 8 + * SELECT * FROM table WHERE i = 5 AND j = i + 3 + * ==> SELECT * FROM table WHERE i = 5 AND j = 8 * }}} * * Approach used: - * - Populate a mapping of expression => constant value by looking at all the deterministic equals - * predicates - * - Using this mapping, replace occurrence of the expressions with the corresponding constant - * values in the AND node. + * - Populate a mapping of attribute => constant value by looking at all the equals predicates + * - Using this mapping, replace occurrence of the attributes with the corresponding constant values + * in the AND node. */ object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => - val (newCondition, _) = traverse(f.condition) - if (newCondition fastEquals f.condition) { - f + val (newCondition, _) = traverse(f.condition, replaceChildren = true) + if (newCondition.isDefined) { + f.copy(condition = newCondition.get) } else { - f.copy(condition = newCondition) + f } } + type EqualityPredicates = Seq[((AttributeReference, Literal), BinaryComparison)] + /** - * Traverse a condition as a tree and replace expressions with constant values. + * Traverse a condition as a tree and replace attributes with constant values. * - On matching [[And]], recursively traverse each children and get propagated mappings. * If the current node is not child of another [[And]], replace all occurrences of the - * expressions with the corresponding constant values. + * attributes with the corresponding constant values. * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate the mapping - * of expression => constant. + * of attribute => constant. * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate empty mapping. * - Otherwise, stop traversal and propagate empty mapping. - * @param expression expression to be traversed + * @param condition condition to be traversed + * @param replaceChildren whether to replace attributes with constant values in children * @return A tuple including: * 1. Option[Expression]: optional changed condition after traversal - * 2. Map[Expression, Literal]: propagated mapping of expression => constant + * 2. EqualityPredicates: propagated mapping of attribute => constant */ - private def traverse(expression: Expression): (Expression, Map[Expression, Literal]) = - expression match { - case e @ EqualTo(left, right: Literal) if !left.foldable && left.deterministic => - (e, Map(left.canonicalized -> right)) - case e @ EqualTo(left: Literal, right) if !right.foldable && right.deterministic => - (e, Map(right.canonicalized -> left)) - case e @ EqualNullSafe(left, right: Literal) if !left.foldable && left.deterministic => - (e, Map(left.canonicalized -> right)) - case e @ EqualNullSafe(left: Literal, right) if !right.foldable && right.deterministic => - (e, Map(right.canonicalized -> left)) - case a @ And(left, right) => - val (newLeft, equalityPredicatesLeft) = traverse(left) - val replacedRight = replaceConstants(right, equalityPredicatesLeft) - val (replacedNewRight, equalityPredicatesRight) = traverse(replacedRight) - val replacedNewLeft = replaceConstants(newLeft, equalityPredicatesRight) - val newAnd = if ((replacedNewLeft fastEquals left) && (replacedNewRight fastEquals right)) { - a + private def traverse(condition: Expression, replaceChildren: Boolean) + : (Option[Expression], EqualityPredicates) = + condition match { + case e @ EqualTo(left: AttributeReference, right: Literal) => (None, Seq(((left, right), e))) + case e @ EqualTo(left: Literal, right: AttributeReference) => (None, Seq(((right, left), e))) + case e @ EqualNullSafe(left: AttributeReference, right: Literal) => + (None, Seq(((left, right), e))) + case e @ EqualNullSafe(left: Literal, right: AttributeReference) => + (None, Seq(((right, left), e))) + case a: And => + val (newLeft, equalityPredicatesLeft) = traverse(a.left, replaceChildren = false) + val (newRight, equalityPredicatesRight) = traverse(a.right, replaceChildren = false) + val equalityPredicates = equalityPredicatesLeft ++ equalityPredicatesRight + val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) { + Some(And(replaceConstants(newLeft.getOrElse(a.left), equalityPredicates), + replaceConstants(newRight.getOrElse(a.right), equalityPredicates))) } else { - And(replacedNewLeft, replacedNewRight) + if (newLeft.isDefined || newRight.isDefined) { + Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right))) + } else { + None + } } - (newAnd, equalityPredicatesLeft ++= equalityPredicatesRight) - case o @ Or(left, right) => + (newSelf, equalityPredicates) + case o: Or => // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = traverse(left) - val (newRight, _) = traverse(right) - val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { - o + val (newLeft, _) = traverse(o.left, replaceChildren = true) + val (newRight, _) = traverse(o.right, replaceChildren = true) + val newSelf = if (newLeft.isDefined || newRight.isDefined) { + Some(Or(left = newLeft.getOrElse(o.left), right = newRight.getOrElse((o.right)))) } else { - Or(newLeft, newRight) + None } - - (newOr, Map.empty) - case n @ Not(child) => + (newSelf, Seq.empty) + case n: Not => // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = traverse(child) - val newNot = if (newChild fastEquals child) { - n - } else { - Not(newChild) - } - (newNot, Map.empty) - case _ => (expression, Map.empty) + val (newChild, _) = traverse(n.child, replaceChildren = true) + (newChild.map(Not), Seq.empty) + case _ => (None, Seq.empty) } - private def replaceConstants( - expression: Expression, - constants: Map[Expression, Literal]) = - expression transformUp { - case e if constants.contains(e.canonicalized) => constants(e.canonicalized) + private def replaceConstants(condition: Expression, equalityPredicates: EqualityPredicates) + : Expression = { + val constantsMap = AttributeMap(equalityPredicates.map(_._1)) + val predicates = equalityPredicates.map(_._2).toSet + def replaceConstants0(expression: Expression) = expression transform { + case a: AttributeReference => constantsMap.getOrElse(a, a) + } + condition transform { + case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e) + case e @ EqualNullSafe(_, _) if !predicates.contains(e) => replaceConstants0(e) } + } } /** @@ -164,6 +169,218 @@ object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { f.copy(condition = newCondition) } } + + private def normalizeAndReduceWithConstraints(expression: Expression): Expression = + reduceWithConstraints(normalize(expression))._1 + + private def normalize(expression: Expression) = expression transform { + case GreaterThan(x, y) => LessThan(y, x) + case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) + } + + /** + * Traverse a condition as a tree and simplify expressions with constraints. + * - This functions assumes that the plan has been normalized using [[normalize()]] + * - On matching [[And]], recursively traverse both children, simplify child expressions with + * propagated constraints from sibling and propagate up union of constraints. + * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]], + * propagate the constraint. + * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints. + * - Otherwise, stop traversal and propagate no constraints. + * @param expression expression to be traversed + * @return A tuple including: + * 1. Expression: optionally changed expression after traversal + * 2. Seq[Expression]: propagated constraints + */ + private def reduceWithConstraints(expression: Expression): (Expression, Seq[Expression]) = + expression match { + case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe) + if e.deterministic => (e, Seq(e)) + case a @ And(left, right) => + val (newLeft, leftConstraints) = reduceWithConstraints(left) + val reducedRight = reduceWithConstraints(right, leftConstraints) + val (reducedNewRight, rightConstraints) = reduceWithConstraints(reducedRight) + val reducedNewLeft = reduceWithConstraints(newLeft, rightConstraints) + val newAnd = if ((reducedNewLeft fastEquals left) && + (reducedNewRight fastEquals right)) { + a + } else { + And(reducedNewLeft, reducedNewRight) + } + (newAnd, leftConstraints ++ rightConstraints) + case o @ Or(left, right) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newLeft, _) = reduceWithConstraints(left) + val (newRight, _) = reduceWithConstraints(right) + val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { + o + } else { + Or(newLeft, newRight) + } + + (newOr, Seq.empty) + case n @ Not(child) => + // Ignore the EqualityPredicates from children since they are only propagated through And. + val (newChild, _) = reduceWithConstraints(child) + val newNot = if (newChild fastEquals child) { + n + } else { + Not(newChild) + } + (newNot, Seq.empty) + case _ => (expression, Seq.empty) + } + + private def reduceWithConstraints(expression: Expression, constraints: Seq[Expression]) = + constraints.foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint)) + + private def planEqual(x: Expression, y: Expression) = + !x.foldable && !y.foldable && x.canonicalized == y.canonicalized + + private def valueEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && EqualTo(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThan(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThan(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def valueLessThanOrEqual(x: Expression, y: Expression) = + x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] + + private def reduceWithConstraint(expression: Expression, constraint: Expression): Expression = + constraint match { + case a LessThan b => expression transformUp { + case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a LessThanOrEqual b => expression transformUp { + case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => + EqualTo(c, d) + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => + EqualTo(c, d) + + case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a EqualTo b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) => + if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => + if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => + if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => + if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) + } + case a EqualNullSafe b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } + case _ => expression + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index e13899a4c98c..cc352c59dff8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -122,216 +122,4 @@ trait ConstraintHelper { case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) case _ => Seq.empty[Attribute] } - - def normalizeAndReduceWithConstraints(expression: Expression): Expression = - reduceWithConstraints(normalize(expression))._1 - - private def normalize(expression: Expression) = expression transform { - case GreaterThan(x, y) => LessThan(y, x) - case GreaterThanOrEqual(x, y) => LessThanOrEqual(y, x) - } - - /** - * Traverse a condition as a tree and simplify expressions with constraints. - * - This functions assumes that the plan has been normalized using [[normalize()]] - * - On matching [[And]], recursively traverse both children, simplify child expressions with - * propagated constraints from sibling and propagate up union of constraints. - * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]], - * propagate the constraint. - * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints. - * - Otherwise, stop traversal and propagate no constraints. - * @param expression expression to be traversed - * @return A tuple including: - * 1. Expression: optionally changed expression after traversal - * 2. Seq[Expression]: propagated constraints - */ - private def reduceWithConstraints(expression: Expression): (Expression, Seq[Expression]) = - expression match { - case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe) - if e.deterministic => (e, Seq(e)) - case a @ And(left, right) => - val (newLeft, leftConstraints) = reduceWithConstraints(left) - val simplifiedRight = reduceWithConstraints(right, leftConstraints) - val (simplifiedNewRight, rightConstraints) = reduceWithConstraints(simplifiedRight) - val simplifiedNewLeft = reduceWithConstraints(newLeft, rightConstraints) - val newAnd = if ((simplifiedNewLeft fastEquals left) && - (simplifiedNewRight fastEquals right)) { - a - } else { - And(simplifiedNewLeft, simplifiedNewRight) - } - (newAnd, leftConstraints ++ rightConstraints) - case o @ Or(left, right) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = reduceWithConstraints(left) - val (newRight, _) = reduceWithConstraints(right) - val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { - o - } else { - Or(newLeft, newRight) - } - - (newOr, Seq.empty) - case n @ Not(child) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = reduceWithConstraints(child) - val newNot = if (newChild fastEquals child) { - n - } else { - Not(newChild) - } - (newNot, Seq.empty) - case _ => (expression, Seq.empty) - } - - private def reduceWithConstraints(expression: Expression, constraints: Seq[Expression]) = - constraints.foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint)) - - private def planEqual(x: Expression, y: Expression) = - !x.foldable && !y.foldable && x.canonicalized == y.canonicalized - - private def valueEqual(x: Expression, y: Expression) = - x.foldable && y.foldable && EqualTo(x, y).eval(EmptyRow).asInstanceOf[Boolean] - - private def valueLessThan(x: Expression, y: Expression) = - x.foldable && y.foldable && LessThan(x, y).eval(EmptyRow).asInstanceOf[Boolean] - - private def valueLessThanOrEqual(x: Expression, y: Expression) = - x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] - - private def reduceWithConstraint(expression: Expression, constraint: Expression): Expression = - constraint match { - case a LessThan b => expression transformUp { - case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c LessThanOrEqual d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThanOrEqual d - if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThanOrEqual d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThanOrEqual d - if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualNullSafe d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) - } - case a LessThanOrEqual b => expression transformUp { - case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => - Literal.TrueLiteral - case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => - Literal.TrueLiteral - case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c LessThanOrEqual d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => - Literal.FalseLiteral - case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => - EqualTo(c, d) - case c LessThanOrEqual d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => - Literal.FalseLiteral - case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => - EqualTo(c, d) - - case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) - } - case a EqualTo b => expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral - - case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral - - case c EqualNullSafe d if planEqual(b, d) => - if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => - if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => - if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => - if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) - } - case a EqualNullSafe b => expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) - - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral - } - case _ => expression - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index bbf390b7ddf6..94174eec8fd0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -106,12 +106,14 @@ class ConstantPropagationSuite extends PlanTest { test("equality predicates outside a `OR` can be propagated within a `OR`") { val query = testRelation + .select(columnA) .where( columnA === Literal(2) && (columnA === Add(columnB, Literal(3)) || columnB === Literal(9))) .analyze val correctAnswer = testRelation + .select(columnA) .where( columnA === Literal(2) && (Literal(2) === Add(columnB, Literal(3)) || columnB === Literal(9))) @@ -151,10 +153,13 @@ class ConstantPropagationSuite extends PlanTest { test("conflicting equality predicates") { val query = testRelation + .select(columnA) .where( columnA === Literal(1) && columnA === Literal(2) && columnB === Add(columnA, Literal(3))) - val correctAnswer = testRelation.where(Literal.FalseLiteral).analyze + val correctAnswer = testRelation + .select(columnA) + .where(columnA === Literal(1) && columnA === Literal(2) && columnB === Literal(5)).analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } From 9dc1422f2b8244811e7861ac0f767039a98e3548 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 9 May 2019 08:30:40 +0200 Subject: [PATCH 10/12] fix review findings Change-Id: Ifef383125dd8827fdeea90e8101cb432de9e704e --- .../optimizer/FilterReductionSuite.scala | 509 +++++++++--------- 1 file changed, 256 insertions(+), 253 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala index 3e711ca39ac5..85c5a2cb4104 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala @@ -59,7 +59,7 @@ class FilterReductionSuite extends PlanTest { val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) - private def testPropagation( + private def testFilterReduction( input: Expression, expectEmptyRelation: Boolean, expectedConstraints: Seq[Expression] = Seq.empty) = { @@ -74,280 +74,283 @@ class FilterReductionSuite extends PlanTest { } test("Filter reduction") { - testPropagation('a < 2 && Literal(2) < 'a, true) - testPropagation('a < 2 && Literal(2) <= 'a, true) - testPropagation('a < 2 && Literal(2) === 'a, true) - testPropagation('a < 2 && Literal(2) <=> 'a, true) - testPropagation('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) - testPropagation('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a <= 2 && Literal(2) < 'a, true) - testPropagation('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) - testPropagation('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a <= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) - testPropagation('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation('a === 2 && Literal(2) < 'a, true) - testPropagation('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(2) > 'a, true) - testPropagation('a <=> 2 && Literal(2) < 'a, true) - testPropagation('a <=> 2 && Literal(2) <= 'a, false, Seq('a <=> 2)) - testPropagation('a <=> 2 && Literal(2) === 'a, false, Seq('a <=> 2)) - testPropagation('a <=> 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) - testPropagation('a <=> 2 && Literal(2) >= 'a, false, Seq('a <=> 2)) - testPropagation('a <=> 2 && Literal(2) > 'a, true) - testPropagation('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) - testPropagation('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) - testPropagation('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) - testPropagation('a >= 2 && Literal(2) > 'a, true) - testPropagation('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) - testPropagation('a > 2 && Literal(2) === 'a, true) - testPropagation('a > 2 && Literal(2) <=> 'a, true) - testPropagation('a > 2 && Literal(2) >= 'a, true) - testPropagation('a > 2 && Literal(2) > 'a, true) + testFilterReduction('a < 2 && Literal(2) < 'a, true) + testFilterReduction('a < 2 && Literal(2) <= 'a, true) + testFilterReduction('a < 2 && Literal(2) === 'a, true) + testFilterReduction('a < 2 && Literal(2) <=> 'a, true) + testFilterReduction('a < 2 && Literal(2) >= 'a, false, Seq('a < 2)) + testFilterReduction('a < 2 && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction('a <= 2 && Literal(2) < 'a, true) + testFilterReduction('a <= 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testFilterReduction('a <= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a <= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a <= 2 && Literal(2) >= 'a, false, Seq('a <= 2)) + testFilterReduction('a <= 2 && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction('a === 2 && Literal(2) < 'a, true) + testFilterReduction('a === 2 && Literal(2) <= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) <=> 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(2) > 'a, true) + testFilterReduction('a <=> 2 && Literal(2) < 'a, true) + testFilterReduction('a <=> 2 && Literal(2) <= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) === 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) >= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(2) > 'a, true) + testFilterReduction('a >= 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a >= 2 && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testFilterReduction('a >= 2 && Literal(2) === 'a, false, Seq('a === 2)) + testFilterReduction('a >= 2 && Literal(2) <=> 'a, false, Seq('a <=> 2)) + testFilterReduction('a >= 2 && Literal(2) >= 'a, false, Seq('a === 2)) + testFilterReduction('a >= 2 && Literal(2) > 'a, true) + testFilterReduction('a > 2 && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a > 2 && Literal(2) <= 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction('a > 2 && Literal(2) === 'a, true) + testFilterReduction('a > 2 && Literal(2) <=> 'a, true) + testFilterReduction('a > 2 && Literal(2) >= 'a, true) + testFilterReduction('a > 2 && Literal(2) > 'a, true) - testPropagation(('x || 'a < 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) - testPropagation(('x || 'a < 2) && Literal(2) <= 'a, false, Seq('x, Literal(2) <= 'a)) - testPropagation(('x || 'a < 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) - testPropagation(('x || 'a < 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) - testPropagation(('x || 'a < 2) && Literal(2) >= 'a, false, Seq('x || 'a < 2, 'a <= 2)) - testPropagation(('x || 'a < 2) && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation(('x || 'a <= 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) - testPropagation(('x || 'a <= 2) && Literal(2) <= 'a, false, + testFilterReduction(('x || 'a < 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) <= 'a, false, Seq('x, Literal(2) <= 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testFilterReduction(('x || 'a < 2) && Literal(2) >= 'a, false, Seq('x || 'a < 2, 'a <= 2)) + testFilterReduction(('x || 'a < 2) && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction(('x || 'a <= 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) <= 'a, false, Seq('x || 'a === 2, Literal(2) <= 'a)) - testPropagation(('x || 'a <= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) - testPropagation(('x || 'a <= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) - testPropagation(('x || 'a <= 2) && Literal(2) >= 'a, false, Seq('a <= 2)) - testPropagation(('x || 'a <= 2) && Literal(2) > 'a, false, Seq('a < 2)) - testPropagation(('x || 'a === 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) - testPropagation(('x || 'a === 2) && Literal(2) <= 'a, false, + testFilterReduction(('x || 'a <= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(2) >= 'a, false, Seq('a <= 2)) + testFilterReduction(('x || 'a <= 2) && Literal(2) > 'a, false, Seq('a < 2)) + testFilterReduction(('x || 'a === 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) <= 'a, false, Seq('x || 'a === 2, Literal(2) <= 'a)) - testPropagation(('x || 'a === 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) - testPropagation(('x || 'a === 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) - testPropagation(('x || 'a === 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) - testPropagation(('x || 'a === 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) - testPropagation(('x || 'a <=> 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) - testPropagation(('x || 'a <=> 2) && Literal(2) <= 'a, false, + testFilterReduction(('x || 'a === 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a === 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testFilterReduction(('x || 'a === 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) < 'a, false, Seq('x, Literal(2) < 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) <= 'a, false, Seq('x || 'a === 2, Literal(2) <= 'a)) - testPropagation(('x || 'a <=> 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) - testPropagation(('x || 'a <=> 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) - testPropagation(('x || 'a <=> 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) - testPropagation(('x || 'a <=> 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) - testPropagation(('x || 'a >= 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation(('x || 'a >= 2) && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) - testPropagation(('x || 'a >= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) - testPropagation(('x || 'a >= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) - testPropagation(('x || 'a >= 2) && Literal(2) >= 'a, false, + testFilterReduction(('x || 'a <=> 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= 2)) + testFilterReduction(('x || 'a <=> 2) && Literal(2) > 'a, false, Seq('x, 'a < 2)) + testFilterReduction(('x || 'a >= 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) <= 'a, false, Seq(Literal(2) <= 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) === 'a, false, Seq(Literal(2) === 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) <=> 'a, false, Seq(Literal(2) <=> 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(2) >= 'a, false, Seq('x || 'a === 2, 'a <= Literal(2))) - testPropagation(('x || 'a >= 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) - testPropagation(('x || 'a > 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) - testPropagation(('x || 'a > 2) && Literal(2) <= 'a, false, + testFilterReduction(('x || 'a >= 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + testFilterReduction(('x || 'a > 2) && Literal(2) < 'a, false, Seq(Literal(2) < 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) <= 'a, false, Seq('x || Literal(2) < 'a, Literal(2) <= 'a)) - testPropagation(('x || 'a > 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) - testPropagation(('x || 'a > 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) - testPropagation(('x || 'a > 2) && Literal(2) >= 'a, false, Seq('x, 'a <= Literal(2))) - testPropagation(('x || 'a > 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) + testFilterReduction(('x || 'a > 2) && Literal(2) === 'a, false, Seq('x, Literal(2) === 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) <=> 'a, false, Seq('x, Literal(2) <=> 'a)) + testFilterReduction(('x || 'a > 2) && Literal(2) >= 'a, false, Seq('x, 'a <= Literal(2))) + testFilterReduction(('x || 'a > 2) && Literal(2) > 'a, false, Seq('x, 'a < Literal(2))) - testPropagation('a < 2 && Literal(3) < 'a, true) - testPropagation('a < 2 && Literal(3) <= 'a, true) - testPropagation('a < 2 && Literal(3) === 'a, true) - testPropagation('a < 2 && Literal(3) <=> 'a, true) - testPropagation('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) - testPropagation('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) - testPropagation('a <= 2 && Literal(3) < 'a, true) - testPropagation('a <= 2 && Literal(3) <= 'a, true) - testPropagation('a <= 2 && Literal(3) === 'a, true) - testPropagation('a <= 2 && Literal(3) <=> 'a, true) - testPropagation('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) - testPropagation('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) - testPropagation('a === 2 && Literal(3) < 'a, true) - testPropagation('a === 2 && Literal(3) <= 'a, true) - testPropagation('a === 2 && Literal(3) === 'a, true) - testPropagation('a === 2 && Literal(3) <=> 'a, true) - testPropagation('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) - testPropagation('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) - testPropagation('a <=> 2 && Literal(3) < 'a, true) - testPropagation('a <=> 2 && Literal(3) <= 'a, true) - testPropagation('a <=> 2 && Literal(3) === 'a, true) - testPropagation('a <=> 2 && Literal(3) <=> 'a, true) - testPropagation('a <=> 2 && Literal(3) >= 'a, false, Seq('a <=> 2)) - testPropagation('a <=> 2 && Literal(3) > 'a, false, Seq('a <=> 2)) - testPropagation('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) - testPropagation('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= 3)) - testPropagation('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < 3)) - testPropagation('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) - testPropagation('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= 3)) - testPropagation('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < 3)) + testFilterReduction('a < 2 && Literal(3) < 'a, true) + testFilterReduction('a < 2 && Literal(3) <= 'a, true) + testFilterReduction('a < 2 && Literal(3) === 'a, true) + testFilterReduction('a < 2 && Literal(3) <=> 'a, true) + testFilterReduction('a < 2 && Literal(3) >= 'a, false, Seq('a < 2)) + testFilterReduction('a < 2 && Literal(3) > 'a, false, Seq('a < 2)) + testFilterReduction('a <= 2 && Literal(3) < 'a, true) + testFilterReduction('a <= 2 && Literal(3) <= 'a, true) + testFilterReduction('a <= 2 && Literal(3) === 'a, true) + testFilterReduction('a <= 2 && Literal(3) <=> 'a, true) + testFilterReduction('a <= 2 && Literal(3) >= 'a, false, Seq('a <= 2)) + testFilterReduction('a <= 2 && Literal(3) > 'a, false, Seq('a <= 2)) + testFilterReduction('a === 2 && Literal(3) < 'a, true) + testFilterReduction('a === 2 && Literal(3) <= 'a, true) + testFilterReduction('a === 2 && Literal(3) === 'a, true) + testFilterReduction('a === 2 && Literal(3) <=> 'a, true) + testFilterReduction('a === 2 && Literal(3) >= 'a, false, Seq('a === 2)) + testFilterReduction('a === 2 && Literal(3) > 'a, false, Seq('a === 2)) + testFilterReduction('a <=> 2 && Literal(3) < 'a, true) + testFilterReduction('a <=> 2 && Literal(3) <= 'a, true) + testFilterReduction('a <=> 2 && Literal(3) === 'a, true) + testFilterReduction('a <=> 2 && Literal(3) <=> 'a, true) + testFilterReduction('a <=> 2 && Literal(3) >= 'a, false, Seq('a <=> 2)) + testFilterReduction('a <=> 2 && Literal(3) > 'a, false, Seq('a <=> 2)) + testFilterReduction('a >= 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction('a >= 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction('a >= 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction('a >= 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction('a >= 2 && Literal(3) >= 'a, false, Seq(Literal(2) <= 'a, 'a <= 3)) + testFilterReduction('a >= 2 && Literal(3) > 'a, false, Seq(Literal(2) <= 'a, 'a < 3)) + testFilterReduction('a > 2 && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction('a > 2 && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction('a > 2 && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction('a > 2 && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction('a > 2 && Literal(3) >= 'a, false, Seq(Literal(2) < 'a, 'a <= 3)) + testFilterReduction('a > 2 && Literal(3) > 'a, false, Seq(Literal(2) < 'a, 'a < 3)) - testPropagation(('x || 'a < 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) - testPropagation(('x || 'a < 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) - testPropagation(('x || 'a < 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) - testPropagation(('x || 'a < 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) - testPropagation(('x || 'a < 2) && Literal(3) >= 'a, false, Seq('x || 'a < 2, 'a <= 3)) - testPropagation(('x || 'a < 2) && Literal(3) > 'a, false, Seq('x || 'a < 2, 'a < 3)) - testPropagation(('x || 'a <= 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) - testPropagation(('x || 'a <= 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) - testPropagation(('x || 'a <= 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) - testPropagation(('x || 'a <= 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) - testPropagation(('x || 'a <= 2) && Literal(3) >= 'a, false, Seq('x || 'a <= 2, 'a <= 3)) - testPropagation(('x || 'a <= 2) && Literal(3) > 'a, false, Seq('x || 'a <= 2, 'a < 3)) - testPropagation(('x || 'a === 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) - testPropagation(('x || 'a === 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) - testPropagation(('x || 'a === 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) - testPropagation(('x || 'a === 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) - testPropagation(('x || 'a === 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) - testPropagation(('x || 'a === 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) - testPropagation(('x || 'a <=> 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) - testPropagation(('x || 'a <=> 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) - testPropagation(('x || 'a <=> 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) - testPropagation(('x || 'a <=> 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) - testPropagation(('x || 'a <=> 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) - testPropagation(('x || 'a <=> 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) - testPropagation(('x || 'a >= 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation(('x || 'a >= 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation(('x || 'a >= 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation(('x || 'a >= 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) - testPropagation(('x || 'a >= 2) && Literal(3) >= 'a, false, + testFilterReduction(('x || 'a < 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a < 2) && Literal(3) >= 'a, false, Seq('x || 'a < 2, 'a <= 3)) + testFilterReduction(('x || 'a < 2) && Literal(3) > 'a, false, Seq('x || 'a < 2, 'a < 3)) + testFilterReduction(('x || 'a <= 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a <= 2) && Literal(3) >= 'a, false, Seq('x || 'a <= 2, 'a <= 3)) + testFilterReduction(('x || 'a <= 2) && Literal(3) > 'a, false, Seq('x || 'a <= 2, 'a < 3)) + testFilterReduction(('x || 'a === 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a === 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testFilterReduction(('x || 'a === 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) < 'a, false, Seq('x, Literal(3) < 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) <= 'a, false, Seq('x, Literal(3) <= 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) === 'a, false, Seq('x, Literal(3) === 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) <=> 'a, false, Seq('x, Literal(3) <=> 'a)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) >= 'a, false, Seq('x || 'a === 2, 'a <= 3)) + testFilterReduction(('x || 'a <=> 2) && Literal(3) > 'a, false, Seq('x || 'a === 2, 'a < 3)) + testFilterReduction(('x || 'a >= 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction(('x || 'a >= 2) && Literal(3) >= 'a, false, Seq('x || Literal(2) <= 'a, 'a <= 3)) - testPropagation(('x || 'a >= 2) && Literal(3) > 'a, false, Seq('x || Literal(2) <= 'a, 'a < 3)) - testPropagation(('x || 'a > 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) - testPropagation(('x || 'a > 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) - testPropagation(('x || 'a > 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) - testPropagation(('x || 'a > 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) - testPropagation(('x || 'a > 2) && Literal(3) >= 'a, false, Seq('x || Literal(2) < 'a, 'a <= 3)) - testPropagation(('x || 'a > 2) && Literal(3) > 'a, false, Seq('x || Literal(2) < 'a, 'a < 3)) + testFilterReduction(('x || 'a >= 2) && Literal(3) > 'a, false, + Seq('x || Literal(2) <= 'a, 'a < 3)) + testFilterReduction(('x || 'a > 2) && Literal(3) < 'a, false, Seq(Literal(3) < 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) <= 'a, false, Seq(Literal(3) <= 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) === 'a, false, Seq(Literal(3) === 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) <=> 'a, false, Seq(Literal(3) <=> 'a)) + testFilterReduction(('x || 'a > 2) && Literal(3) >= 'a, false, + Seq('x || Literal(2) < 'a, 'a <= 3)) + testFilterReduction(('x || 'a > 2) && Literal(3) > 'a, false, + Seq('x || Literal(2) < 'a, 'a < 3)) - testPropagation('a < 'b && 'b < 'a, true) - testPropagation('a < 'b && 'b <= 'a, true) - testPropagation('a < 'b && 'b === 'a, true) - testPropagation('a < 'b && 'b <=> 'a, true) - testPropagation('a < 'b && 'b >= 'a, false, Seq('a < 'b)) - testPropagation('a < 'b && 'b > 'a, false, Seq('a < 'b)) - testPropagation('a <= 'b && 'b < 'a, true) - testPropagation('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) - testPropagation('a <= 'b && 'b > 'a, false, Seq('a < 'b)) - testPropagation('a === 'b && 'b < 'a, true) - testPropagation('a === 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a === 'b && 'b > 'a, true) - testPropagation('a <=> 'b && 'b < 'a, true) - testPropagation('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) - testPropagation('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a <=> 'b && 'b > 'a, true) - testPropagation('a >= 'b && 'b < 'a, false, Seq('b < 'a)) - testPropagation('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) - testPropagation('a >= 'b && 'b === 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) - testPropagation('a >= 'b && 'b > 'a, true) - testPropagation('a > 'b && 'b < 'a, false, Seq('b < 'a)) - testPropagation('a > 'b && 'b <= 'a, false, Seq('b < 'a)) - testPropagation('a > 'b && 'b === 'a, true) - testPropagation('a > 'b && 'b <=> 'a, true) - testPropagation('a > 'b && 'b >= 'a, true) - testPropagation('a > 'b && 'b > 'a, true) + testFilterReduction('a < 'b && 'b < 'a, true) + testFilterReduction('a < 'b && 'b <= 'a, true) + testFilterReduction('a < 'b && 'b === 'a, true) + testFilterReduction('a < 'b && 'b <=> 'a, true) + testFilterReduction('a < 'b && 'b >= 'a, false, Seq('a < 'b)) + testFilterReduction('a < 'b && 'b > 'a, false, Seq('a < 'b)) + testFilterReduction('a <= 'b && 'b < 'a, true) + testFilterReduction('a <= 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a <= 'b && 'b >= 'a, false, Seq('a <= 'b)) + testFilterReduction('a <= 'b && 'b > 'a, false, Seq('a < 'b)) + testFilterReduction('a === 'b && 'b < 'a, true) + testFilterReduction('a === 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a === 'b && 'b > 'a, true) + testFilterReduction('a <=> 'b && 'b < 'a, true) + testFilterReduction('a <=> 'b && 'b <= 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b <=> 'a, false, Seq('a <=> 'b)) + testFilterReduction('a <=> 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a <=> 'b && 'b > 'a, true) + testFilterReduction('a >= 'b && 'b < 'a, false, Seq('b < 'a)) + testFilterReduction('a >= 'b && 'b <= 'a, false, Seq('b <= 'a)) + testFilterReduction('a >= 'b && 'b === 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b <=> 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b >= 'a, false, Seq('a === 'b)) + testFilterReduction('a >= 'b && 'b > 'a, true) + testFilterReduction('a > 'b && 'b < 'a, false, Seq('b < 'a)) + testFilterReduction('a > 'b && 'b <= 'a, false, Seq('b < 'a)) + testFilterReduction('a > 'b && 'b === 'a, true) + testFilterReduction('a > 'b && 'b <=> 'a, true) + testFilterReduction('a > 'b && 'b >= 'a, true) + testFilterReduction('a > 'b && 'b > 'a, true) - testPropagation('a < abs('b) && abs('b) < 'a, true) - testPropagation('a < abs('b) && abs('b) <= 'a, true) - testPropagation('a < abs('b) && abs('b) === 'a, true) - testPropagation('a < abs('b) && abs('b) <=> 'a, true) - testPropagation('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) - testPropagation('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation('a <= abs('b) && abs('b) < 'a, true) - testPropagation('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) - testPropagation('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation('a === abs('b) && abs('b) < 'a, true) - testPropagation('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a === abs('b) && abs('b) > 'a, true) - testPropagation('a <=> abs('b) && abs('b) < 'a, true) - testPropagation('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) - testPropagation('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a <=> abs('b) && abs('b) > 'a, true) - testPropagation('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) - testPropagation('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) - testPropagation('a >= abs('b) && abs('b) > 'a, true) - testPropagation('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) - testPropagation('a > abs('b) && abs('b) === 'a, true) - testPropagation('a > abs('b) && abs('b) <=> 'a, true) - testPropagation('a > abs('b) && abs('b) >= 'a, true) - testPropagation('a > abs('b) && abs('b) > 'a, true) + testFilterReduction('a < abs('b) && abs('b) < 'a, true) + testFilterReduction('a < abs('b) && abs('b) <= 'a, true) + testFilterReduction('a < abs('b) && abs('b) === 'a, true) + testFilterReduction('a < abs('b) && abs('b) <=> 'a, true) + testFilterReduction('a < abs('b) && abs('b) >= 'a, false, Seq('a < abs('b))) + testFilterReduction('a < abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction('a <= abs('b) && abs('b) < 'a, true) + testFilterReduction('a <= abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a <= abs('b) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testFilterReduction('a <= abs('b) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction('a === abs('b) && abs('b) < 'a, true) + testFilterReduction('a === abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a === abs('b) && abs('b) > 'a, true) + testFilterReduction('a <=> abs('b) && abs('b) < 'a, true) + testFilterReduction('a <=> abs('b) && abs('b) <= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) <=> 'a, false, Seq('a <=> abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a <=> abs('b) && abs('b) > 'a, true) + testFilterReduction('a >= abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a >= abs('b) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testFilterReduction('a >= abs('b) && abs('b) === 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) <=> 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) >= 'a, false, Seq('a === abs('b))) + testFilterReduction('a >= abs('b) && abs('b) > 'a, true) + testFilterReduction('a > abs('b) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a > abs('b) && abs('b) <= 'a, false, Seq(abs('b) < 'a)) + testFilterReduction('a > abs('b) && abs('b) === 'a, true) + testFilterReduction('a > abs('b) && abs('b) <=> 'a, true) + testFilterReduction('a > abs('b) && abs('b) >= 'a, true) + testFilterReduction('a > abs('b) && abs('b) > 'a, true) - testPropagation(('x || 'a < abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) - testPropagation(('x || 'a < abs('b)) && abs('b) <= 'a, false, Seq('x, abs('b) <= 'a)) - testPropagation(('x || 'a < abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) - testPropagation(('x || 'a < abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) - testPropagation(('x || 'a < abs('b)) && abs('b) >= 'a, false, + testFilterReduction(('x || 'a < abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) <= 'a, false, Seq('x, abs('b) <= 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testFilterReduction(('x || 'a < abs('b)) && abs('b) >= 'a, false, Seq('x || 'a < abs('b), 'a <= abs('b))) - testPropagation(('x || 'a < abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a <= abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) - testPropagation(('x || 'a <= abs('b)) && abs('b) <= 'a, false, + testFilterReduction(('x || 'a < abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) <= 'a, false, Seq('x || 'a === abs('b), abs('b) <= 'a)) - testPropagation(('x || 'a <= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <= abs('b)) && abs('b) <=> 'a, false, + testFilterReduction(('x || 'a <= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) <=> 'a, false, Seq('x || 'a === abs('b), abs('b) <=> 'a)) - testPropagation(('x || 'a <= abs('b)) && abs('b) >= 'a, false, Seq('a <= abs('b))) - testPropagation(('x || 'a <= abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) - testPropagation(('x || 'a === abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) - testPropagation(('x || 'a === abs('b)) && abs('b) <= 'a, false, + testFilterReduction(('x || 'a <= abs('b)) && abs('b) >= 'a, false, Seq('a <= abs('b))) + testFilterReduction(('x || 'a <= abs('b)) && abs('b) > 'a, false, Seq('a < abs('b))) + testFilterReduction(('x || 'a === abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) <= 'a, false, Seq('x || 'a === abs('b), abs('b) <= 'a)) - testPropagation(('x || 'a === abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a === abs('b)) && abs('b) <=> 'a, false, + testFilterReduction(('x || 'a === abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a === abs('b)) && abs('b) <=> 'a, false, Seq('x || 'a === abs('b), abs('b) <=> 'a)) - testPropagation(('x || 'a === abs('b)) && abs('b) >= 'a, false, + testFilterReduction(('x || 'a === abs('b)) && abs('b) >= 'a, false, Seq('x || 'a === abs('b), 'a <= abs('b))) - testPropagation(('x || 'a === abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) - testPropagation(('x || 'a <=> abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) - testPropagation(('x || 'a <=> abs('b)) && abs('b) <= 'a, false, + testFilterReduction(('x || 'a === abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) < 'a, false, Seq('x, abs('b) < 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) <= 'a, false, Seq('x || 'a === abs('b), abs('b) <= 'a)) - testPropagation(('x || 'a <=> abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a <=> abs('b)) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) - testPropagation(('x || 'a <=> abs('b)) && abs('b) >= 'a, false, + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) <=> 'a, false, Seq(abs('b) <=> 'a)) + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) >= 'a, false, Seq('x || 'a === abs('b), 'a <= abs('b))) - testPropagation(('x || 'a <=> abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) - testPropagation(('x || 'a >= abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a >= abs('b)) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) - testPropagation(('x || 'a >= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) - testPropagation(('x || 'a >= abs('b)) && abs('b) <=> 'a, false, + testFilterReduction(('x || 'a <=> abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) <= 'a, false, Seq(abs('b) <= 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) === 'a, false, Seq(abs('b) === 'a)) + testFilterReduction(('x || 'a >= abs('b)) && abs('b) <=> 'a, false, Seq('x || 'a === abs('b), abs('b) <=> 'a)) - testPropagation(('x || 'a >= abs('b)) && abs('b) >= 'a, false, + testFilterReduction(('x || 'a >= abs('b)) && abs('b) >= 'a, false, Seq('x || 'a === abs('b), 'a <= abs('b))) - testPropagation(('x || 'a >= abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) - testPropagation(('x || 'a > abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) - testPropagation(('x || 'a > abs('b)) && abs('b) <= 'a, false, + testFilterReduction(('x || 'a >= abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a > abs('b)) && abs('b) < 'a, false, Seq(abs('b) < 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) <= 'a, false, Seq('x || abs('b) < 'a, abs('b) <= 'a)) - testPropagation(('x || 'a > abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) - testPropagation(('x || 'a > abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) - testPropagation(('x || 'a > abs('b)) && abs('b) >= 'a, false, Seq('x, 'a <= abs('b))) - testPropagation(('x || 'a > abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) + testFilterReduction(('x || 'a > abs('b)) && abs('b) === 'a, false, Seq('x, abs('b) === 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) <=> 'a, false, Seq('x, abs('b) <=> 'a)) + testFilterReduction(('x || 'a > abs('b)) && abs('b) >= 'a, false, Seq('x, 'a <= abs('b))) + testFilterReduction(('x || 'a > abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) } } From 699fe2fbaf81a0da3b6c09f998bb0f02c91671af Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 10 May 2019 15:09:41 +0200 Subject: [PATCH 11/12] fix null issues --- .../sql/catalyst/optimizer/expressions.scala | 335 +++++++++--------- .../optimizer/FilterReductionSuite.scala | 126 ++++++- .../spark/sql/catalyst/plans/PlanTest.scala | 48 ++- 3 files changed, 317 insertions(+), 192 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7329727c629e..386542583d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -171,7 +171,7 @@ object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { } private def normalizeAndReduceWithConstraints(expression: Expression): Expression = - reduceWithConstraints(normalize(expression))._1 + reduceWithConstraints(normalize(expression), true)._1 private def normalize(expression: Expression) = expression transform { case GreaterThan(x, y) => LessThan(y, x) @@ -183,24 +183,32 @@ object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { * - This functions assumes that the plan has been normalized using [[normalize()]] * - On matching [[And]], recursively traverse both children, simplify child expressions with * propagated constraints from sibling and propagate up union of constraints. - * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]], [[EqualNullSafe]], + * - If a child of [[And]] is [[LessThan]], [[LessThanOrEqual]], [[EqualTo]] or [[EqualNullSafe]], * propagate the constraint. - * - On matching [[Or]] or [[Not]], recursively traverse each children, propagate no constraints. + * - On matching [[Or]], [[If]], [[CaseWhen]] or [[Not]] recursively traverse each children, but + * propagate up no constraints. + * - Starting off from a condition expression of a [[Filter]] as top node, to the bottom of the + * expression tree, [[And]], [[Or]], [[If]] and [[CaseWhen]] nodes are considered as safe, + * non-[[NullIntolerant]] nodes where reduction rules can be executed without nullability check. * - Otherwise, stop traversal and propagate no constraints. * @param expression expression to be traversed + * @param nullIsFalse defines if a null value can be considered as false * @return A tuple including: * 1. Expression: optionally changed expression after traversal * 2. Seq[Expression]: propagated constraints */ - private def reduceWithConstraints(expression: Expression): (Expression, Seq[Expression]) = + private def reduceWithConstraints( + expression: Expression, + nullIsFalse: Boolean): (Expression, Seq[Expression]) = expression match { case e @ (_: LessThan | _: LessThanOrEqual | _: EqualTo | _: EqualNullSafe) if e.deterministic => (e, Seq(e)) case a @ And(left, right) => - val (newLeft, leftConstraints) = reduceWithConstraints(left) - val reducedRight = reduceWithConstraints(right, leftConstraints) - val (reducedNewRight, rightConstraints) = reduceWithConstraints(reducedRight) - val reducedNewLeft = reduceWithConstraints(newLeft, rightConstraints) + val (newLeft, leftConstraints) = reduceWithConstraints(left, nullIsFalse) + val reducedRight = reduceWithConstraints(right, leftConstraints, nullIsFalse) + val (reducedNewRight, rightConstraints) = + reduceWithConstraints(reducedRight, nullIsFalse) + val reducedNewLeft = reduceWithConstraints(newLeft, rightConstraints, nullIsFalse) val newAnd = if ((reducedNewLeft fastEquals left) && (reducedNewRight fastEquals right)) { a @@ -208,31 +216,19 @@ object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { And(reducedNewLeft, reducedNewRight) } (newAnd, leftConstraints ++ rightConstraints) - case o @ Or(left, right) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newLeft, _) = reduceWithConstraints(left) - val (newRight, _) = reduceWithConstraints(right) - val newOr = if ((newLeft fastEquals left) && (newRight fastEquals right)) { - o - } else { - Or(newLeft, newRight) - } - - (newOr, Seq.empty) - case n @ Not(child) => - // Ignore the EqualityPredicates from children since they are only propagated through And. - val (newChild, _) = reduceWithConstraints(child) - val newNot = if (newChild fastEquals child) { - n - } else { - Not(newChild) - } - (newNot, Seq.empty) + case o @ (_: Or | _: If | _: CaseWhen) => + (o.mapChildren(reduceWithConstraints(_, nullIsFalse)._1), Seq.empty) + case n: Not => + (n.mapChildren(reduceWithConstraints(_, false)._1), Seq.empty) case _ => (expression, Seq.empty) } - private def reduceWithConstraints(expression: Expression, constraints: Seq[Expression]) = - constraints.foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint)) + private def reduceWithConstraints( + expression: Expression, + constraints: Seq[Expression], + nullIsFalse: Boolean) = + constraints + .foldLeft(expression)((e, constraint) => reduceWithConstraint(e, constraint, nullIsFalse)) private def planEqual(x: Expression, y: Expression) = !x.foldable && !y.foldable && x.canonicalized == y.canonicalized @@ -246,140 +242,155 @@ object FilterReduction extends Rule[LogicalPlan] with ConstraintHelper { private def valueLessThanOrEqual(x: Expression, y: Expression) = x.foldable && y.foldable && LessThanOrEqual(x, y).eval(EmptyRow).asInstanceOf[Boolean] - private def reduceWithConstraint(expression: Expression, constraint: Expression): Expression = - constraint match { - case a LessThan b => expression transformUp { - case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c LessThanOrEqual d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThanOrEqual d - if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThanOrEqual d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThanOrEqual d - if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualNullSafe d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.FalseLiteral - case c EqualNullSafe d - if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) - } - case a LessThanOrEqual b => expression transformUp { - case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => - Literal.TrueLiteral - case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => - Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => - Literal.TrueLiteral - case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => - Literal.FalseLiteral - - case c LessThanOrEqual d - if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => - Literal.FalseLiteral - case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => - EqualTo(c, d) - case c LessThanOrEqual d - if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => - Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => - Literal.FalseLiteral - case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => - EqualTo(c, d) - - case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral - case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral - case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral - case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral - - case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) - } - case a EqualTo b => expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral - - case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral - - case c EqualNullSafe d if planEqual(b, d) => - if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(b, c) => - if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, d) => - if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) - case c EqualNullSafe d if planEqual(a, c) => - if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) + private def reduceWithConstraint( + expression: Expression, + constraint: Expression, + nullIsFalse: Boolean): Expression = + if (nullIsFalse || constraint.children.forall(!_.nullable)) { + constraint match { + case a LessThan b => expression transformUp { + case c LessThan d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualTo d if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.FalseLiteral + case c EqualNullSafe d + if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a LessThanOrEqual b => expression transformUp { + case c LessThan d if planEqual(b, d) && valueLessThan(c, a) => + Literal.TrueLiteral + case c LessThan d if planEqual(b, c) && (planEqual(a, d) || valueLessThanOrEqual(d, a)) => + Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && valueLessThan(b, d) => + Literal.TrueLiteral + case c LessThan d if planEqual(a, d) && (planEqual(b, c) || valueLessThanOrEqual(b, c)) => + Literal.FalseLiteral + + case c LessThanOrEqual d + if planEqual(b, d) && (planEqual(a, c) || valueLessThanOrEqual(c, a)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && valueLessThan(d, a) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(b, c) && (planEqual(a, d) || valueEqual(a, d)) => + EqualTo(c, d) + case c LessThanOrEqual d + if planEqual(a, c) && (planEqual(b, d) || valueLessThanOrEqual(b, d)) => + Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && valueLessThan(b, c) => + Literal.FalseLiteral + case c LessThanOrEqual d if planEqual(a, d) && (planEqual(b, c) || valueEqual(b, c)) => + EqualTo(c, d) + + case c EqualTo d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualTo d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) && valueLessThan(c, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(b, c) && valueLessThan(d, a) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, c) && valueLessThan(b, d) => Literal.FalseLiteral + case c EqualNullSafe d if planEqual(a, d) && valueLessThan(b, c) => Literal.FalseLiteral + + case c EqualNullSafe d if planEqual(b, d) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => EqualTo(c, d) + } + case a EqualTo b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(a, d) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(b, d) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualTo d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualTo d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + + case c EqualNullSafe d if planEqual(b, d) => + if (planEqual(a, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(b, c) => + if (planEqual(a, d)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, d) => + if (planEqual(b, c)) Literal.TrueLiteral else EqualTo(c, d) + case c EqualNullSafe d if planEqual(a, c) => + if (planEqual(b, d)) Literal.TrueLiteral else EqualTo(c, d) + } + case a EqualNullSafe b => expression transformUp { + case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral + case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral + case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral + case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral + + case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) + case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) + + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } + case _ => expression } - case a EqualNullSafe b => expression transformUp { - case c LessThan d if planEqual(b, d) && planEqual(a, c) => Literal.FalseLiteral - case c LessThan d if planEqual(b, c) && planEqual(d, a) => Literal.FalseLiteral - case c LessThan d if planEqual(a, d) && planEqual(b, c) => Literal.FalseLiteral - case c LessThan d if planEqual(a, c) && planEqual(d, b) => Literal.FalseLiteral - - case c LessThanOrEqual d if planEqual(b, d) && planEqual(a, c) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(b, c) && planEqual(a, d) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, d) && planEqual(b, c) => EqualTo(c, d) - case c LessThanOrEqual d if planEqual(a, c) && planEqual(b, d) => EqualTo(c, d) - - case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral - case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } else { + constraint match { + case a EqualNullSafe b => expression transformUp { + case c EqualNullSafe d if planEqual(b, d) && planEqual(a, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(b, c) && planEqual(a, d) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, d) && planEqual(b, c) => Literal.TrueLiteral + case c EqualNullSafe d if planEqual(a, c) && planEqual(b, d) => Literal.TrueLiteral + } + case _ => expression } - case _ => expression } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala index 85c5a2cb4104..fad2ab17c577 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterReductionSuite.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor */ class FilterReductionSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = + trait OptimizeBase extends RuleExecutor[LogicalPlan] { + protected def batches = Batch("AnalysisNodes", Once, EliminateSubqueryAliases) :: Batch("FilterReduction", FixedPoint(10), @@ -44,17 +44,26 @@ class FilterReductionSuite extends PlanTest { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'x.boolean) + object Optimize extends OptimizeBase + + object OptimizeWithoutFilterReduction extends OptimizeBase { + override protected def batches = + super.batches.map(b => Batch(b.name, b.strategy, b.rules.filterNot(_ == FilterReduction): _*)) + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int.notNull, 'd.int.notNull, 'x.boolean) val data = { - val intElements = Seq(null, 1, 2, 3) - val booleanElements = Seq(null, true, false) + val intElementsWithNull = Seq(null, 1, 2, 3, 4) + val intElementsWithoutNull = Seq(null, 1, 2, 3, 4) + val booleanElementsWithNull = Seq(null, true, false) for { - a <- intElements - b <- intElements - c <- intElements - x <- booleanElements - } yield (a, b, c, x) + a <- intElementsWithNull + b <- intElementsWithNull + c <- intElementsWithoutNull + d <- intElementsWithoutNull + x <- booleanElementsWithNull + } yield (a, b, c, d, x) } val testRelationWithData = LocalRelation.fromExternalRows(testRelation.output, data.map(Row(_))) @@ -67,13 +76,22 @@ class FilterReductionSuite extends PlanTest { val optimized = Optimize.execute(originalQuery) val correctAnswer = if (expectEmptyRelation) { testRelation + } else if (expectedConstraints.isEmpty) { + testRelationWithData } else { testRelationWithData.where(expectedConstraints.reduce(And)).analyze } comparePlans(optimized, correctAnswer) } - test("Filter reduction") { + private def testSameAsWithoutFilterReduction(input: Expression) = { + val originalQuery = testRelationWithData.where(input).analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = OptimizeWithoutFilterReduction.execute(originalQuery) + comparePlans(optimized, correctAnswer) + } + + test("Filter reduction with nullable attributes") { testFilterReduction('a < 2 && Literal(2) < 'a, true) testFilterReduction('a < 2 && Literal(2) <= 'a, true) testFilterReduction('a < 2 && Literal(2) === 'a, true) @@ -353,4 +371,90 @@ class FilterReductionSuite extends PlanTest { testFilterReduction(('x || 'a > abs('b)) && abs('b) >= 'a, false, Seq('x, 'a <= abs('b))) testFilterReduction(('x || 'a > abs('b)) && abs('b) > 'a, false, Seq('x, 'a < abs('b))) } + + // These cases test scenarios when there is NullIntolerant node (ex. Not) between Filter and the + // subtree of And nodes and the expression to be reduced is nullable. + // For example in these cases the following reduction does not hold when X and Y is null and so + // FilterReduction should do nothing: + // Not(X < Y && Y < X) => Not(X < Y && false) + // Not(X < Y && Not(X < Y)) => Not(X < Y && Not(true)) + // Not(X <=> Y && Y < X) => Not(X <=> Y && false) + // Not(X < Y && Y <=> X) => Not(X < Y && false) + test("Filter reduction with nullable attributes and NullIntolerant nodes") { + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a < abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <= abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a === abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) === 'a)) + // the only exception: + // testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a <=> abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a >= abs('b)) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || 'a > abs('b)) && abs('b) > 'a)) + + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a < abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <= abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a === abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) === 'a)) + // the only exception: + // testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a <=> abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a >= abs('b))) && abs('b) > 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) < 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) <= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) === 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) <=> 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) >= 'a)) + testSameAsWithoutFilterReduction(Not(('x || Not('a > abs('b))) && abs('b) > 'a)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index ecb69c109951..b7ccf04cd5da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -95,39 +95,49 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case Filter(condition: Expression, child: LogicalPlan) => - val newCondition = - splitConjunctivePredicates(condition) - .map(rewriteEqual) - .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) - .reduce(And) - Filter(newCondition, child) + Filter(normalize(condition), child) case sample: Sample => sample.copy(seed = 0L) case Join(left, right, joinType, condition, hint) if condition.isDefined => - val newCondition = - splitConjunctivePredicates(condition.get) - .map(rewriteEqual) - .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) - .reduce(And) - Join(left, right, joinType, Some(newCondition), hint) + Join(left, right, joinType, condition.map(normalize), hint) } } /** - * Rewrite [[EqualTo]] and [[EqualNullSafe]] operator to keep order. The following cases will be - * equivalent: - * 1. (a = b), (b = a); - * 2. (a <=> b), (b <=> a). + * Rewrite [[EqualTo]], [[EqualNullSafe]], [[GreaterThan]], [[GreaterThanOrEqual]], [[And]] and + * [[Or]] operators to keep order. + * The following pairs will be equivalent: + * 1. (a = b) and (b = a), + * 2. (a <=> b) and (b <=> a), + * 3. (a > b) and (b < a), + * 4. (a >= b) and (b <= a), + * 5. (a <= b AND b <= a) and (b <= a AND a <= b), + * 6. (a <= b OR b <= a) and (b <= a OR a <= b) */ - private def rewriteEqual(condition: Expression): Expression = condition transform { - case eq @ EqualTo(l: Expression, r: Expression) => + private def normalize(expression: Expression): Expression = expression match { + case EqualTo(l: Expression, r: Expression) => Seq(l, r) + .map(normalize) .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) .reduce(EqualTo) - case eq @ EqualNullSafe(l: Expression, r: Expression) => + case EqualNullSafe(l: Expression, r: Expression) => Seq(l, r) + .map(normalize) .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) .reduce(EqualNullSafe) + case GreaterThan(l, r) => LessThan(normalize(r), normalize(l)) + case GreaterThanOrEqual(l, r) => LessThanOrEqual(normalize(r), normalize(l)) + case and: And => + splitConjunctivePredicates(and) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(And) + case or: Or => + splitDisjunctivePredicates(or) + .map(normalize) + .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) + .reduce(Or) + case _ => expression } /** Fails the test if the two plans do not match */ From 0972a296d6e9d37cb0def385eae96749b127f3a2 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 10 May 2019 17:07:55 +0200 Subject: [PATCH 12/12] plantest fix --- .../scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index b7ccf04cd5da..32de8b6d88cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -137,7 +137,7 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite => .map(normalize) .sortBy(p => scala.util.hashing.MurmurHash3.seqHash(Seq(p.getClass, p))) .reduce(Or) - case _ => expression + case _ => expression.mapChildren(normalize) } /** Fails the test if the two plans do not match */