@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
5151 override protected val blacklistedOnceBatches : Set [String ] =
5252 Set (
5353 " PartitionPruning" ,
54- " Extract Python UDFs" )
54+ " Extract Python UDFs" ,
55+ " Push predicate through join by CNF" )
5556
5657 protected def fixedPoint =
5758 FixedPoint (
@@ -121,7 +122,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
121122 rulesWithoutInferFiltersFromConstraints : _* ) ::
122123 // Set strategy to Once to avoid pushing filter every time because we do not change the
123124 // join condition.
124- Batch (" Push predicate through join by conjunctive normal form " , Once ,
125+ Batch (" Push predicate through join by CNF " , Once ,
125126 PushPredicateThroughJoinByCNF ) :: Nil
126127 }
127128
@@ -1383,10 +1384,11 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13831384object PushPredicateThroughJoinByCNF extends Rule [LogicalPlan ] with PredicateHelper {
13841385 /**
13851386 * Rewrite pattern:
1386- * 1. (a && b) || c --> (a || c) && (b || c)
1387- * 2. a || (b && c) --> (a || b) && (a || c)
1387+ * 1. (a && b) || (c && d) --> (a || c) && (a || d) && (b || c) && (b && d)
1388+ * 2. (a && b) || c --> (a || c) && (b || c)
1389+ * 3. a || (b && c) --> (a || b) && (a || c)
13881390 *
1389- * To avoid generating too many predicates, we first group the filter columns from the same table.
1391+ * To avoid generating too many predicates, we first group the columns from the same table.
13901392 */
13911393 private def toCNF (condition : Expression , depth : Int = 0 ): Expression = {
13921394 if (depth < SQLConf .get.maxRewritingCNFDepth) {
@@ -1395,12 +1397,12 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
13951397 val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
13961398 val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
13971399 if (lhs.size > 1 ) {
1398- lhs.values.map(_.reduceLeft(And )).map { c =>
1399- toCNF(Or (toCNF(c , depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
1400+ lhs.values.map(_.reduceLeft(And )).map { e =>
1401+ toCNF(Or (toCNF(e , depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
14001402 }.reduce(And )
14011403 } else if (rhs.size > 1 ) {
1402- rhs.values.map(_.reduceLeft(And )).map { c =>
1403- toCNF(Or (toCNF(left, depth + 1 ), toCNF(c , depth + 1 )), depth + 1 )
1404+ rhs.values.map(_.reduceLeft(And )).map { e =>
1405+ toCNF(Or (toCNF(left, depth + 1 ), toCNF(e , depth + 1 )), depth + 1 )
14041406 }.reduce(And )
14051407 } else {
14061408 or
@@ -1409,8 +1411,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14091411 case or @ Or (left : And , right) =>
14101412 val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
14111413 if (lhs.size > 1 ) {
1412- lhs.values.map(_.reduceLeft(And )).map {
1413- c => toCNF(Or (toCNF(c , depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
1414+ lhs.values.map(_.reduceLeft(And )).map { e =>
1415+ toCNF(Or (toCNF(e , depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
14141416 }.reduce(And )
14151417 } else {
14161418 or
@@ -1419,8 +1421,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14191421 case or @ Or (left, right : And ) =>
14201422 val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
14211423 if (rhs.size > 1 ) {
1422- rhs.values.map(_.reduceLeft(And )).map { c =>
1423- toCNF(Or (toCNF(left, depth + 1 ), toCNF(c , depth + 1 )), depth + 1 )
1424+ rhs.values.map(_.reduceLeft(And )).map { e =>
1425+ toCNF(Or (toCNF(left, depth + 1 ), toCNF(e , depth + 1 )), depth + 1 )
14241426 }.reduce(And )
14251427 } else {
14261428 or
@@ -1437,32 +1439,19 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14371439 }
14381440 }
14391441
1440- private def maybeWithFilter (joinCondition : Option [Expression ], plan : LogicalPlan ) = {
1441- (joinCondition, plan) match {
1442- // Avoid adding the same filter.
1443- case (Some (condition), filter : Filter ) if condition.semanticEquals(filter.condition) =>
1444- plan
1445- case (Some (condition), _) =>
1446- Filter (condition, plan)
1447- case _ =>
1448- plan
1449- }
1450- }
1451-
1452- def apply (plan : LogicalPlan ): LogicalPlan = plan transform applyLocally
1453-
1454- val applyLocally : PartialFunction [LogicalPlan , LogicalPlan ] = {
1442+ def apply (plan : LogicalPlan ): LogicalPlan = plan transform {
14551443 case j @ Join (left, right, joinType, Some (joinCondition), hint) =>
1456-
14571444 val pushDownCandidates =
14581445 splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic)
1459- val (leftEvaluateCondition , rest) =
1446+ val (leftFilterConditions , rest) =
14601447 pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
1461- val (rightEvaluateCondition , _) =
1448+ val (rightFilterConditions , _) =
14621449 rest.partition(expr => expr.references.subsetOf(right.outputSet))
14631450
1464- val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And ), left)
1465- val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And ), right)
1451+ val newLeft = leftFilterConditions.
1452+ reduceLeftOption(And ).map(Filter (_, left)).getOrElse(left)
1453+ val newRight = rightFilterConditions.
1454+ reduceLeftOption(And ).map(Filter (_, right)).getOrElse(right)
14661455
14671456 joinType match {
14681457 case _ : InnerLike | LeftSemi =>
0 commit comments