diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index c4243da7b9e4..8919d19fc972 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -60,18 +60,25 @@ trait ConstraintHelper {
    */
   def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
     var inferredConstraints = ExpressionSet()
-    // IsNotNull should be constructed by `constructIsNotNullConstraints`.
-    val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
-    predicates.foreach {
-      case eq @ EqualTo(l: Attribute, r: Attribute) =>
-        val candidateConstraints = predicates - eq
-        inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
-        inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
-      case eq @ EqualTo(l @ Cast(_: Attribute, _, _), r: Attribute) =>
-        inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
-      case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
-        inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
-      case _ => // No inference
+    var prevSize = -1
+    while (inferredConstraints.size > prevSize) {
+      prevSize = inferredConstraints.size
+      val predicates = (constraints ++ inferredConstraints)
+        // IsNotNull should be constructed by `constructIsNotNullConstraints`.
+        .filterNot(_.isInstanceOf[IsNotNull])
+        // Non-deterministic expressions are never equal to each other and would cause OOM
+        .filter(_.deterministic)
+      predicates.foreach {
+        case eq @ Equality(l: Attribute, r: Attribute) =>
+          val candidateConstraints = predicates - eq
+          inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
+          inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
+        case eq @ Equality(l @ Cast(_: Attribute, _, _), r: Attribute) =>
+          inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
+        case eq @ Equality(l: Attribute, r @ Cast(_: Attribute, _, _)) =>
+          inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
+        case _ => // No inference
+      }
     }
     inferredConstraints -- constraints
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 8c111aa75080..e09207a0ce76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -351,6 +351,7 @@ case class Join(
         .union(ExpressionSet(splitConjunctivePredicates(condition.get)))
     case LeftSemi if condition.isDefined =>
       left.constraints
+        .union(right.constraints)
         .union(ExpressionSet(splitConjunctivePredicates(condition.get)))
     case j: ExistenceJoin =>
       left.constraints
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index 79bd573f1d84..4265d34074e9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -29,15 +29,21 @@ import org.apache.spark.sql.types.{IntegerType, LongType}
 
 class InferFiltersFromConstraintsSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("InferAndPushDownFilters", FixedPoint(100),
-        PushPredicateThroughJoin,
-        PushPredicateThroughNonJoin,
-        InferFiltersFromConstraints,
-        CombineFilters,
-        SimplifyBinaryComparison,
-        BooleanSimplification,
-        PruneFilters) :: Nil
+    val operatorOptimizationRuleSet = Seq(
+      PushDownPredicates,
+      BooleanSimplification,
+      SimplifyBinaryComparison,
+      PruneFilters)
+
+    val batches = Batch(
+      "Operator Optimization before Inferring Filters",
+      FixedPoint(100),
+      operatorOptimizationRuleSet: _*) ::
+      Batch("Infer Filters", Once, InferFiltersFromConstraints) ::
+      Batch(
+        "Operator Optimization after Inferring Filters",
+        FixedPoint(100),
+        operatorOptimizationRuleSet: _*) :: Nil
   }
 
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -316,4 +322,75 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
         condition)
     }
   }
+
+  test(
+    "SPARK-32801: Single inner join with EqualNullSafe condition: " +
+      "filter out values on either side on equi-join keys") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery =
+      x.join(y, condition = Some(("x.a".attr <=> "y.a".attr) && ("x.a".attr > 5))).analyze
+    val left = x.where(IsNotNull('a) && "x.a".attr > 5)
+    val right = y.where(IsNotNull('a) && "y.a".attr > 5)
+    val correctAnswer = left.join(right, condition = Some("x.a".attr <=> "y.a".attr)).analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Infer all constraints from a chain of filters") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x
+      .where("x.a".attr === "x.b".attr)
+      .join(y, condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .analyze
+    val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.a".attr === "x.b".attr)
+    val right = y.where(IsNotNull('a) && IsNotNull('b) && "y.a".attr === "y.b".attr)
+    val correctAnswer = left
+      .join(right, condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Infer from right side of left semi join") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val z = testRelation.subquery('z)
+    val originalQuery = x
+      .join(
+        y.join(
+          z.where("z.a".attr > 1),
+          condition = Some("y.a".attr === "z.a".attr),
+          joinType = LeftSemi),
+        condition = Some("x.a".attr === "y.a".attr))
+      .analyze
+    val correctX = x.where(IsNotNull('a) && "x.a".attr > 1)
+    val correctY = y.where(IsNotNull('a) && "y.a".attr > 1)
+    val correctZ = z.where(IsNotNull('a) && "z.a".attr > 1)
+    val correctAnswer = correctX
+      .join(
+        correctY.join(correctZ, condition = Some("y.a".attr === "z.a".attr), joinType = LeftSemi),
+        condition = Some("x.a".attr === "y.a".attr))
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32801: Non-deterministic filters do not introduce an infinite loop") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+    val originalQuery = x
+      .join(y, condition = Some("x.a".attr === "y.a".attr))
+      .where(rand(0) === "x.a".attr)
+      .analyze
+    val left = x.where(IsNotNull('a))
+    val right = y.where(IsNotNull('a))
+    val correctAnswer = left
+      .join(right, condition = Some("x.a".attr === "y.a".attr))
+      .where(rand(0) === "x.a".attr)
+      .analyze
+    val optimized = Optimize.execute(originalQuery)
+    comparePlans(optimized, correctAnswer)
+  }
 }