@@ -20,12 +20,10 @@ package org.apache.spark.sql.catalyst.optimizer
2020import scala .collection .mutable
2121
2222import org .apache .spark .sql .AnalysisException
23- import org .apache .spark .sql .catalyst .InternalRow
2423import org .apache .spark .sql .catalyst .analysis ._
2524import org .apache .spark .sql .catalyst .catalog .{InMemoryCatalog , SessionCatalog }
2625import org .apache .spark .sql .catalyst .expressions ._
2726import org .apache .spark .sql .catalyst .expressions .aggregate ._
28- import org .apache .spark .sql .catalyst .expressions .codegen .CodegenFallback
2927import org .apache .spark .sql .catalyst .plans ._
3028import org .apache .spark .sql .catalyst .plans .logical ._
3129import org .apache .spark .sql .catalyst .rules ._
@@ -121,6 +119,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
121119 InferFiltersFromConstraints ) ::
122120 Batch (" Operator Optimization after Inferring Filters" , fixedPoint,
123121 rulesWithoutInferFiltersFromConstraints : _* ) ::
122+ // Set strategy to Once to avoid pushing filter every time because we do not change the
123+ // join condition.
124124 Batch (" Push predicate through join by conjunctive normal form" , Once ,
125125 PushPredicateThroughJoinByCNF ) :: Nil
126126 }
@@ -1381,80 +1381,65 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13811381 * more predicate.
13821382 */
13831383object PushPredicateThroughJoinByCNF extends Rule [LogicalPlan ] with PredicateHelper {
1384-
1385- // Used to group same side expressions to avoid generating too many duplicate predicates.
1386- private case class SameSide (exps : Seq [Expression ]) extends CodegenFallback {
1387- override def children : Seq [Expression ] = exps
1388- override def nullable : Boolean = true
1389- override def dataType : DataType = throw new UnsupportedOperationException
1390- override def eval (input : InternalRow ): Any = throw new UnsupportedOperationException
1391- }
1392-
13931384 /**
13941385 * Rewrite pattern:
13951386 * 1. (a && b) || c --> (a || c) && (b || c)
13961387 * 2. a || (b && c) --> (a || b) && (a || c)
1397- * 3. !(a || b) --> !a && !b
1388+ *
1389+ * To avoid generating too many predicates, we first group the filter columns from the same table.
13981390 */
1399- private def rewriteToCNF (condition : Expression , depth : Int = 0 ): Expression = {
1391+ private def toCNF (condition : Expression , depth : Int = 0 ): Expression = {
14001392 if (depth < SQLConf .get.maxRewritingCNFDepth) {
1401- val nextDepth = depth + 1
14021393 condition match {
1403- case Or (And (a, b), c) =>
1404- And (rewriteToCNF(Or (rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth),
1405- rewriteToCNF(Or (rewriteToCNF(b, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
1406- case Or (a, And (b, c)) =>
1407- And (rewriteToCNF(Or (rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)), nextDepth),
1408- rewriteToCNF(Or (rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
1409- case Not (Or (a, b)) =>
1410- And (rewriteToCNF(Not (rewriteToCNF(a, nextDepth)), nextDepth),
1411- rewriteToCNF(Not (rewriteToCNF(b, nextDepth)), nextDepth))
1412- case And (a, b) =>
1413- And (rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth))
1414- case other => other
1415- }
1416- } else {
1417- condition
1418- }
1419- }
1394+ case or @ Or (left : And , right : And ) =>
1395+ val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
1396+ val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
1397+ if (lhs.size > 1 ) {
1398+ lhs.values.map(_.reduceLeft(And )).map { c =>
1399+ toCNF(Or (toCNF(c, depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
1400+ }.reduce(And )
1401+ } else if (rhs.size > 1 ) {
1402+ rhs.values.map(_.reduceLeft(And )).map { c =>
1403+ toCNF(Or (toCNF(left, depth + 1 ), toCNF(c, depth + 1 )), depth + 1 )
1404+ }.reduce(And )
1405+ } else {
1406+ or
1407+ }
14201408
1421- /**
1422- * Split And expression by single side references. For example,
1423- * t1.a > 1 and t1.a < 10 and t2.a < 10 -->
1424- * SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10)
1425- */
1426- private def splitAndExp (and : And , outputSet : AttributeSet ) = {
1427- val (leftSide, rightSide) =
1428- splitConjunctivePredicates(and).partition(_.references.subsetOf(outputSet))
1429- Seq (SameSide (leftSide), SameSide (rightSide)).filter(_.exps.nonEmpty).reduceLeft(And )
1430- }
1409+ case or @ Or (left : And , right) =>
1410+ val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
1411+ if (lhs.size > 1 ) {
1412+ lhs.values.map(_.reduceLeft(And )).map {
1413+ c => toCNF(Or (toCNF(c, depth + 1 ), toCNF(right, depth + 1 )), depth + 1 )
1414+ }.reduce(And )
1415+ } else {
1416+ or
1417+ }
14311418
1432- private def splitCondition (condition : Expression , outputSet : AttributeSet ): Expression = {
1433- condition.transformUp {
1434- case Or (a : And , b : And ) =>
1435- Or (splitAndExp(a, outputSet), splitAndExp(b, outputSet))
1436- case Or (a : And , b) =>
1437- Or (splitAndExp(a, outputSet), b)
1438- case Or (a, b : And ) =>
1439- Or (a, splitAndExp(b, outputSet))
1440- }
1441- }
1419+ case or @ Or (left, right : And ) =>
1420+ val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
1421+ if (rhs.size > 1 ) {
1422+ rhs.values.map(_.reduceLeft(And )).map { c =>
1423+ toCNF(Or (toCNF(left, depth + 1 ), toCNF(c, depth + 1 )), depth + 1 )
1424+ }.reduce(And )
1425+ } else {
1426+ or
1427+ }
14421428
1443- // Restore expressions from SameSide.
1444- private def restoreExps (condition : Expression ): Expression = {
1445- condition match {
1446- case SameSide (exps) =>
1447- exps.reduceLeft(And )
1448- case Or (a, b) =>
1449- Or (restoreExps(a), restoreExps(b))
1450- case And (a, b) =>
1451- And (restoreExps(a), restoreExps(b))
1452- case other => other
1429+ case And (left, right) =>
1430+ And (toCNF(left, depth + 1 ), toCNF(right, depth + 1 ))
1431+
1432+ case other =>
1433+ other
1434+ }
1435+ } else {
1436+ condition
14531437 }
14541438 }
14551439
14561440 private def maybeWithFilter (joinCondition : Option [Expression ], plan : LogicalPlan ) = {
14571441 (joinCondition, plan) match {
1442+ // Avoid adding the same filter.
14581443 case (Some (condition), filter : Filter ) if condition.semanticEquals(filter.condition) =>
14591444 plan
14601445 case (Some (condition), _) =>
@@ -1470,17 +1455,14 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14701455 case j @ Join (left, right, joinType, Some (joinCondition), hint) =>
14711456
14721457 val pushDownCandidates =
1473- splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet)))
1474- .filter(_.deterministic)
1458+ splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic)
14751459 val (leftEvaluateCondition, rest) =
14761460 pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
14771461 val (rightEvaluateCondition, _) =
14781462 rest.partition(expr => expr.references.subsetOf(right.outputSet))
14791463
1480- val newLeft =
1481- maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And ).map(restoreExps), left)
1482- val newRight =
1483- maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And ).map(restoreExps), right)
1464+ val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And ), left)
1465+ val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And ), right)
14841466
14851467 joinType match {
14861468 case _ : InnerLike | LeftSemi =>
0 commit comments