Skip to content

Commit e8f2471

Browse files
committed
Remove maybeWithFilter and add PushPredicateThroughJoinByCNF to blacklistedOnceBatches
1 parent 6c44d64 commit e8f2471

File tree

3 files changed

+30
-37
lines changed

3 files changed

+30
-37
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
5151
override protected val blacklistedOnceBatches: Set[String] =
5252
Set(
5353
"PartitionPruning",
54-
"Extract Python UDFs")
54+
"Extract Python UDFs",
55+
"Push predicate through join by CNF")
5556

5657
protected def fixedPoint =
5758
FixedPoint(
@@ -121,7 +122,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
121122
rulesWithoutInferFiltersFromConstraints: _*) ::
122123
// Set strategy to Once to avoid pushing filter every time because we do not change the
123124
// join condition.
124-
Batch("Push predicate through join by conjunctive normal form", Once,
125+
Batch("Push predicate through join by CNF", Once,
125126
PushPredicateThroughJoinByCNF) :: Nil
126127
}
127128

@@ -1383,10 +1384,11 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13831384
object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
13841385
/**
13851386
* Rewrite pattern:
1386-
* 1. (a && b) || c --> (a || c) && (b || c)
1387-
* 2. a || (b && c) --> (a || b) && (a || c)
1387+
* 1. (a && b) || (c && d) --> (a || c) && (a || d) && (b || c) && (b && d)
1388+
* 2. (a && b) || c --> (a || c) && (b || c)
1389+
* 3. a || (b && c) --> (a || b) && (a || c)
13881390
*
1389-
* To avoid generating too many predicates, we first group the filter columns from the same table.
1391+
* To avoid generating too many predicates, we first group the columns from the same table.
13901392
*/
13911393
private def toCNF(condition: Expression, depth: Int = 0): Expression = {
13921394
if (depth < SQLConf.get.maxRewritingCNFDepth) {
@@ -1395,12 +1397,12 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
13951397
val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
13961398
val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
13971399
if (lhs.size > 1) {
1398-
lhs.values.map(_.reduceLeft(And)).map { c =>
1399-
toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1)
1400+
lhs.values.map(_.reduceLeft(And)).map { e =>
1401+
toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1)
14001402
}.reduce(And)
14011403
} else if (rhs.size > 1) {
1402-
rhs.values.map(_.reduceLeft(And)).map { c =>
1403-
toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1)
1404+
rhs.values.map(_.reduceLeft(And)).map { e =>
1405+
toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1)
14041406
}.reduce(And)
14051407
} else {
14061408
or
@@ -1409,8 +1411,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14091411
case or @ Or(left: And, right) =>
14101412
val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
14111413
if (lhs.size > 1) {
1412-
lhs.values.map(_.reduceLeft(And)).map {
1413-
c => toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1)
1414+
lhs.values.map(_.reduceLeft(And)).map { e =>
1415+
toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1)
14141416
}.reduce(And)
14151417
} else {
14161418
or
@@ -1419,8 +1421,8 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14191421
case or @ Or(left, right: And) =>
14201422
val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
14211423
if (rhs.size > 1) {
1422-
rhs.values.map(_.reduceLeft(And)).map { c =>
1423-
toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1)
1424+
rhs.values.map(_.reduceLeft(And)).map { e =>
1425+
toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1)
14241426
}.reduce(And)
14251427
} else {
14261428
or
@@ -1437,32 +1439,19 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14371439
}
14381440
}
14391441

1440-
private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = {
1441-
(joinCondition, plan) match {
1442-
// Avoid adding the same filter.
1443-
case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
1444-
plan
1445-
case (Some(condition), _) =>
1446-
Filter(condition, plan)
1447-
case _ =>
1448-
plan
1449-
}
1450-
}
1451-
1452-
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
1453-
1454-
val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
1442+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
14551443
case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
1456-
14571444
val pushDownCandidates =
14581445
splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic)
1459-
val (leftEvaluateCondition, rest) =
1446+
val (leftFilterConditions, rest) =
14601447
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
1461-
val (rightEvaluateCondition, _) =
1448+
val (rightFilterConditions, _) =
14621449
rest.partition(expr => expr.references.subsetOf(right.outputSet))
14631450

1464-
val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And), left)
1465-
val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And), right)
1451+
val newLeft = leftFilterConditions.
1452+
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
1453+
val newRight = rightFilterConditions.
1454+
reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
14661455

14671456
joinType match {
14681457
case _: InnerLike | LeftSemi =>

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,8 @@ object SQLConf {
548548
buildConf("spark.sql.maxRewritingCNFDepth")
549549
.internal()
550550
.doc("The maximum depth of rewriting a join condition to conjunctive normal form " +
551-
"expression. The deeper, the more predicate may be found, but the optimization time " +
552-
"will increase. The default is 6. By setting this value to 0 this feature can be disabled.")
551+
"expression. The deeper, the more predicate may be found, but the optimization time will " +
552+
"increase. The default is 10. By setting this value to 0 this feature can be disabled.")
553553
.version("3.1.0")
554554
.intConf
555555
.checkValue(_ >= 0,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ import org.apache.spark.unsafe.types.CalendarInterval
3232
class FilterPushdownSuite extends PlanTest {
3333

3434
object Optimize extends RuleExecutor[LogicalPlan] {
35+
36+
override protected val blacklistedOnceBatches: Set[String] =
37+
Set("Push predicate through join by CNF")
38+
3539
val batches =
3640
Batch("Subqueries", Once,
3741
EliminateSubqueryAliases) ::
@@ -41,7 +45,7 @@ class FilterPushdownSuite extends PlanTest {
4145
BooleanSimplification,
4246
PushPredicateThroughJoin,
4347
CollapseProject) ::
44-
Batch("PushPredicateThroughJoinByCNF", Once,
48+
Batch("Push predicate through join by CNF", Once,
4549
PushPredicateThroughJoinByCNF) :: Nil
4650
}
4751

@@ -1336,7 +1340,7 @@ class FilterPushdownSuite extends PlanTest {
13361340
comparePlans(optimized, correctAnswer)
13371341
}
13381342

1339-
test("inner join: rewrite to conjunctive normal form avoid genereting too many predicates") {
1343+
test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
13401344
val x = testRelation.subquery('x)
13411345
val y = testRelation.subquery('y)
13421346

0 commit comments

Comments
 (0)