Skip to content

Commit b38404c

Browse files
committed
Avoid genereting too many predicates
1 parent 21fb7c5 commit b38404c

File tree

3 files changed

+75
-7
lines changed

3 files changed

+75
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.sql.AnalysisException
23+
import org.apache.spark.sql.catalyst.InternalRow
2324
import org.apache.spark.sql.catalyst.analysis._
2425
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
2526
import org.apache.spark.sql.catalyst.expressions._
2627
import org.apache.spark.sql.catalyst.expressions.aggregate._
28+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2729
import org.apache.spark.sql.catalyst.plans._
2830
import org.apache.spark.sql.catalyst.plans.logical._
2931
import org.apache.spark.sql.catalyst.rules._
@@ -1380,6 +1382,14 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
13801382
*/
13811383
object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
13821384

1385+
// Used to group same side expressions to avoid generating too many duplicate predicates.
1386+
private case class SameSide(exps: Seq[Expression]) extends CodegenFallback {
1387+
override def children: Seq[Expression] = exps
1388+
override def nullable: Boolean = true
1389+
override def dataType: DataType = throw new UnsupportedOperationException
1390+
override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
1391+
}
1392+
13831393
/**
13841394
* Rewrite pattern:
13851395
* 1. (a && b) || c --> (a || c) && (b || c)
@@ -1408,8 +1418,43 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14081418
}
14091419
}
14101420

1411-
private def maybeWithFilter(joinCondition: Seq[Expression], plan: LogicalPlan) = {
1412-
(joinCondition.reduceLeftOption(And).reduceLeftOption(And), plan) match {
1421+
/**
1422+
* Split And expression by single side references. For example,
1423+
* t1.a > 1 and t1.a < 10 and t2.a < 10 -->
1424+
* SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10)
1425+
*/
1426+
private def splitAndExp(and: And, outputSet: AttributeSet) = {
1427+
val (leftSide, rightSide) =
1428+
splitConjunctivePredicates(and).partition(_.references.subsetOf(outputSet))
1429+
Seq(SameSide(leftSide), SameSide(rightSide)).filter(_.exps.nonEmpty).reduceLeft(And)
1430+
}
1431+
1432+
private def splitCondition(condition: Expression, outputSet: AttributeSet): Expression = {
1433+
condition.transformUp {
1434+
case Or(a: And, b: And) =>
1435+
Or(splitAndExp(a, outputSet), splitAndExp(b, outputSet))
1436+
case Or(a: And, b) =>
1437+
Or(splitAndExp(a, outputSet), b)
1438+
case Or(a, b: And) =>
1439+
Or(a, splitAndExp(b, outputSet))
1440+
}
1441+
}
1442+
1443+
// Restore expressions from SameSide.
1444+
private def restoreExps(condition: Expression): Expression = {
1445+
condition match {
1446+
case SameSide(exps) =>
1447+
exps.reduceLeft(And)
1448+
case Or(a, b) =>
1449+
Or(restoreExps(a), restoreExps(b))
1450+
case And(a, b) =>
1451+
And(restoreExps(a), restoreExps(b))
1452+
case other => other
1453+
}
1454+
}
1455+
1456+
private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = {
1457+
(joinCondition, plan) match {
14131458
case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
14141459
plan
14151460
case (Some(condition), _) =>
@@ -1424,15 +1469,18 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHel
14241469
val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
14251470
case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
14261471

1427-
val pushDownCandidates = splitConjunctivePredicates(rewriteToCNF(joinCondition))
1428-
.filter(_.deterministic)
1472+
val pushDownCandidates =
1473+
splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet)))
1474+
.filter(_.deterministic)
14291475
val (leftEvaluateCondition, rest) =
14301476
pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
14311477
val (rightEvaluateCondition, _) =
14321478
rest.partition(expr => expr.references.subsetOf(right.outputSet))
14331479

1434-
val newLeft = maybeWithFilter(leftEvaluateCondition, left)
1435-
val newRight = maybeWithFilter(rightEvaluateCondition, right)
1480+
val newLeft =
1481+
maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And).map(restoreExps), left)
1482+
val newRight =
1483+
maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And).map(restoreExps), right)
14361484

14371485
joinType match {
14381486
case _: InnerLike | LeftSemi =>

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ object SQLConf {
554554
.intConf
555555
.checkValue(_ >= 0,
556556
"The depth of the maximum rewriting conjunction normal form must be positive.")
557-
.createWithDefault(6)
557+
.createWithDefault(10)
558558

559559
val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
560560
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,4 +1334,24 @@ class FilterPushdownSuite extends PlanTest {
13341334

13351335
comparePlans(optimized, correctAnswer)
13361336
}
1337+
1338+
test("inner join: rewrite to conjunctive normal form avoid genereting too many predicates") {
1339+
val x = testRelation.subquery('x)
1340+
val y = testRelation.subquery('y)
1341+
1342+
val originalQuery = {
1343+
x.join(y, condition = Some(("x.b".attr === "y.b".attr)
1344+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1345+
|| (("y.a".attr > 2) && ("y.c".attr < 1)))))
1346+
}
1347+
1348+
val optimized = Optimize.execute(originalQuery.analyze)
1349+
val left = testRelation.subquery('x)
1350+
val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
1351+
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
1352+
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
1353+
|| (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
1354+
1355+
comparePlans(optimized, correctAnswer)
1356+
}
13371357
}

0 commit comments

Comments
 (0)