@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
2020import scala .collection .mutable
2121
2222import org .apache .spark .sql .AnalysisException
23+ import org .apache .spark .sql .catalyst .InternalRow
2324import org .apache .spark .sql .catalyst .analysis ._
2425import org .apache .spark .sql .catalyst .catalog .{InMemoryCatalog , SessionCatalog }
2526import org .apache .spark .sql .catalyst .expressions ._
2627import org .apache .spark .sql .catalyst .expressions .aggregate ._
28+ import org .apache .spark .sql .catalyst .expressions .codegen .CodegenFallback
2729import org .apache .spark .sql .catalyst .plans ._
2830import org .apache .spark .sql .catalyst .plans .logical ._
2931import org .apache .spark .sql .catalyst .rules ._
@@ -1380,6 +1382,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13801382 */
13811383object PushPredicateThroughJoinByCNF extends Rule [LogicalPlan ] with PredicateHelper {
13821384
/**
 * Placeholder expression that bundles predicates belonging to the same group so that they
 * travel through the CNF rewrite as a single unit. This avoids generating too many
 * duplicate predicates.
 *
 * It is a planning-time marker only: asking for its data type or evaluating it throws
 * [[UnsupportedOperationException]].
 *
 * @param exps the grouped predicates; they are exposed as this node's children
 */
private case class SameSide(exps: Seq[Expression]) extends CodegenFallback {

  // The grouped predicates are the only children of this node.
  override def children: Seq[Expression] = exps

  // Never meant to be typed: this node is a grouping marker, not a real predicate.
  override def dataType: DataType = {
    throw new UnsupportedOperationException
  }

  override def nullable: Boolean = true

  // Never meant to be evaluated either.
  override def eval(input: InternalRow): Any = {
    throw new UnsupportedOperationException
  }
}
1392+
13831393 /**
13841394 * Rewrite pattern:
13851395 * 1. (a && b) || c --> (a || c) && (b || c)
@@ -1408,8 +1418,43 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14081418 }
14091419 }
14101420
1411- private def maybeWithFilter (joinCondition : Seq [Expression ], plan : LogicalPlan ) = {
1412- (joinCondition.reduceLeftOption(And ).reduceLeftOption(And ), plan) match {
/**
 * Groups the conjuncts of an And expression by whether they reference only attributes of
 * `outputSet`. For example, with `outputSet` being the output of t1,
 *   t1.a > 1 and t1.a < 10 and t2.a < 10 -->
 *   SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10)
 *
 * Conjuncts whose references are a subset of `outputSet` form the first group; all
 * remaining conjuncts form the second group. An empty group is dropped, so the result is
 * either a single SameSide or an And of two SameSide nodes.
 */
private def splitAndExp(and: And, outputSet: AttributeSet) = {
  val conjuncts = splitConjunctivePredicates(and)
  val (coveredByOutput, remaining) = conjuncts.partition { predicate =>
    predicate.references.subsetOf(outputSet)
  }
  // Only wrap groups that actually contain predicates.
  val groups = Seq(coveredByOutput, remaining).collect {
    case group if group.nonEmpty => SameSide(group)
  }
  groups.reduceLeft(And)
}
1431+
/**
 * Walks `condition` bottom-up and, for every Or node, replaces each operand that is an And
 * with its grouped form produced by `splitAndExp`. Or nodes without an And operand, and
 * all other nodes, are left as they are.
 *
 * @param condition the expression to rewrite
 * @param outputSet the attribute set used by `splitAndExp` to group conjuncts
 */
private def splitCondition(condition: Expression, outputSet: AttributeSet): Expression = {
  condition.transformUp {
    case disjunction @ Or(lhs, rhs) =>
      (lhs, rhs) match {
        case (l: And, r: And) =>
          Or(splitAndExp(l, outputSet), splitAndExp(r, outputSet))
        case (l: And, _) =>
          Or(splitAndExp(l, outputSet), rhs)
        case (_, r: And) =>
          Or(lhs, splitAndExp(r, outputSet))
        case _ =>
          // Neither operand is an And: keep the node untouched.
          disjunction
      }
  }
}
1442+
/**
 * Restores the original predicates from every SameSide marker found anywhere in
 * `condition`.
 *
 * A SameSide node is replaced by the conjunction (And) of its grouped expressions. The
 * grouped expressions are restored recursively first, because `splitCondition` rewrites
 * bottom-up and can therefore place a SameSide inside the expressions of another
 * SameSide. Every other node is rebuilt from its restored children, so markers nested
 * below nodes other than And / Or are unwrapped as well and no SameSide (which cannot be
 * typed or evaluated) is left in the returned expression.
 *
 * @param condition an expression that may contain SameSide markers
 * @return the equivalent expression with all SameSide markers removed
 */
private def restoreExps(condition: Expression): Expression = {
  condition match {
    case SameSide(exps) =>
      // SameSide is only built from non-empty groups, so reduceLeft is safe here.
      exps.map(restoreExps).reduceLeft(And)
    case other =>
      other.mapChildren(restoreExps)
  }
}
1455+
1456+ private def maybeWithFilter (joinCondition : Option [Expression ], plan : LogicalPlan ) = {
1457+ (joinCondition, plan) match {
14131458 case (Some (condition), filter : Filter ) if condition.semanticEquals(filter.condition) =>
14141459 plan
14151460 case (Some (condition), _) =>
@@ -1424,15 +1469,18 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14241469 val applyLocally : PartialFunction [LogicalPlan , LogicalPlan ] = {
14251470 case j @ Join (left, right, joinType, Some (joinCondition), hint) =>
14261471
1427- val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition))
1428- .filter(_.deterministic)
1472+ val pushDownCandidates =
1473+ splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet)))
1474+ .filter(_.deterministic)
14291475 val (leftEvaluateCondition, rest) =
14301476 pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
14311477 val (rightEvaluateCondition, _) =
14321478 rest.partition(expr => expr.references.subsetOf(right.outputSet))
14331479
1434- val newLeft = maybeWithFilter(leftEvaluateCondition, left)
1435- val newRight = maybeWithFilter(rightEvaluateCondition, right)
1480+ val newLeft =
1481+ maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And ).map(restoreExps), left)
1482+ val newRight =
1483+ maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And ).map(restoreExps), right)
14361484
14371485 joinType match {
14381486 case _ : InnerLike | LeftSemi =>
0 commit comments