
Commit 6c44d64

Remove SameSide
1 parent b38404c commit 6c44d64

File tree

2 files changed: +78 -67 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 49 additions & 67 deletions
@@ -20,12 +20,10 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -121,6 +119,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
       InferFiltersFromConstraints) ::
     Batch("Operator Optimization after Inferring Filters", fixedPoint,
       rulesWithoutInferFiltersFromConstraints: _*) ::
+    // Run this batch with the Once strategy: the rule does not change the join
+    // condition itself, so repeated applications would only re-push the same filters.
     Batch("Push predicate through join by conjunctive normal form", Once,
       PushPredicateThroughJoinByCNF) :: Nil
   }
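Why Once rather than fixedPoint: a fixed-point batch re-runs its rules until the plan stops changing, and a rule that adds Filter nodes without altering the join condition would attempt the same pushdown on every pass. A minimal sketch of the two batch strategies against Catalyst's RuleExecutor API (NoopRule is a hypothetical placeholder, not part of this commit):

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}

// Hypothetical do-nothing rule, only to make the batch declarations compile.
object NoopRule extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan
}

object SketchExecutor extends RuleExecutor[LogicalPlan] {
  override protected def batches: Seq[Batch] = Seq(
    // Applied exactly one time per optimization run.
    Batch("run exactly once", Once, NoopRule),
    // Re-applied until the plan reaches a fixed point (or 100 iterations).
    Batch("run until the plan stops changing", FixedPoint(100), NoopRule))
}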
@@ -1381,80 +1381,65 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
  * more predicate.
  */
 object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
-
-  // Used to group same side expressions to avoid generating too many duplicate predicates.
-  private case class SameSide(exps: Seq[Expression]) extends CodegenFallback {
-    override def children: Seq[Expression] = exps
-    override def nullable: Boolean = true
-    override def dataType: DataType = throw new UnsupportedOperationException
-    override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
-  }
-
   /**
    * Rewrite pattern:
    * 1. (a && b) || c --> (a || c) && (b || c)
    * 2. a || (b && c) --> (a || b) && (a || c)
-   * 3. !(a || b) --> !a && !b
+   *
+   * To avoid generating too many predicates, we first group conjuncts that reference
+   * the same table.
    */
-  private def rewriteToCNF(condition: Expression, depth: Int = 0): Expression = {
+  private def toCNF(condition: Expression, depth: Int = 0): Expression = {
     if (depth < SQLConf.get.maxRewritingCNFDepth) {
-      val nextDepth = depth + 1
       condition match {
-        case Or(And(a, b), c) =>
-          And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth),
-            rewriteToCNF(Or(rewriteToCNF(b, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
-        case Or(a, And(b, c)) =>
-          And(rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth)), nextDepth),
-            rewriteToCNF(Or(rewriteToCNF(a, nextDepth), rewriteToCNF(c, nextDepth)), nextDepth))
-        case Not(Or(a, b)) =>
-          And(rewriteToCNF(Not(rewriteToCNF(a, nextDepth)), nextDepth),
-            rewriteToCNF(Not(rewriteToCNF(b, nextDepth)), nextDepth))
-        case And(a, b) =>
-          And(rewriteToCNF(a, nextDepth), rewriteToCNF(b, nextDepth))
-        case other => other
-      }
-    } else {
-      condition
-    }
-  }
+        case or @ Or(left: And, right: And) =>
+          val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
+          val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
+          if (lhs.size > 1) {
+            lhs.values.map(_.reduceLeft(And)).map { c =>
+              toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1)
+            }.reduce(And)
+          } else if (rhs.size > 1) {
+            rhs.values.map(_.reduceLeft(And)).map { c =>
+              toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1)
+            }.reduce(And)
+          } else {
+            or
+          }
 
-  /**
-   * Split And expression by single side references. For example,
-   * t1.a > 1 and t1.a < 10 and t2.a < 10 -->
-   * SameSide(t1.a > 1, t1.a < 10) and SameSide(t2.a < 10)
-   */
-  private def splitAndExp(and: And, outputSet: AttributeSet) = {
-    val (leftSide, rightSide) =
-      splitConjunctivePredicates(and).partition(_.references.subsetOf(outputSet))
-    Seq(SameSide(leftSide), SameSide(rightSide)).filter(_.exps.nonEmpty).reduceLeft(And)
-  }
+        case or @ Or(left: And, right) =>
+          val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier))
+          if (lhs.size > 1) {
+            lhs.values.map(_.reduceLeft(And)).map { c =>
+              toCNF(Or(toCNF(c, depth + 1), toCNF(right, depth + 1)), depth + 1)
+            }.reduce(And)
+          } else {
+            or
+          }
 
-  private def splitCondition(condition: Expression, outputSet: AttributeSet): Expression = {
-    condition.transformUp {
-      case Or(a: And, b: And) =>
-        Or(splitAndExp(a, outputSet), splitAndExp(b, outputSet))
-      case Or(a: And, b) =>
-        Or(splitAndExp(a, outputSet), b)
-      case Or(a, b: And) =>
-        Or(a, splitAndExp(b, outputSet))
-    }
-  }
+        case or @ Or(left, right: And) =>
+          val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier))
+          if (rhs.size > 1) {
+            rhs.values.map(_.reduceLeft(And)).map { c =>
+              toCNF(Or(toCNF(left, depth + 1), toCNF(c, depth + 1)), depth + 1)
+            }.reduce(And)
+          } else {
+            or
+          }
 
-  // Restore expressions from SameSide.
-  private def restoreExps(condition: Expression): Expression = {
-    condition match {
-      case SameSide(exps) =>
-        exps.reduceLeft(And)
-      case Or(a, b) =>
-        Or(restoreExps(a), restoreExps(b))
-      case And(a, b) =>
-        And(restoreExps(a), restoreExps(b))
-      case other => other
+        case And(left, right) =>
+          And(toCNF(left, depth + 1), toCNF(right, depth + 1))
+
+        case other =>
+          other
+      }
+    } else {
+      condition
     }
   }
 
   private def maybeWithFilter(joinCondition: Option[Expression], plan: LogicalPlan) = {
     (joinCondition, plan) match {
+      // Avoid adding the same filter twice.
       case (Some(condition), filter: Filter) if condition.semanticEquals(filter.condition) =>
         plan
       case (Some(condition), _) =>
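To see what the grouping buys, here is a self-contained sketch of the middle case, Or(left: And, right), in plain Scala, pasteable into a Scala REPL. The toy Expr tree is not Catalyst code, and Leaf's table tag stands in for references.map(_.qualifier). Grouping first means the distribution produces one Or-clause per table instead of one per conjunct:

sealed trait Expr
case class Leaf(sql: String, table: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr
case class Or(left: Expr, right: Expr) extends Expr

// Analogue of splitConjunctivePredicates: flatten nested Ands into a list.
def conjuncts(e: Expr): Seq[Expr] = e match {
  case And(l, r) => conjuncts(l) ++ conjuncts(r)
  case other => Seq(other)
}

def toCNF(e: Expr): Expr = e match {
  case or @ Or(left: And, right) =>
    // Group conjuncts by table, as the rule groups by attribute qualifier.
    val groups = conjuncts(left).groupBy {
      case Leaf(_, table) => table
      case _ => ""
    }
    if (groups.size > 1) {
      // Fold each group back into one conjunct, then distribute over the Or.
      groups.values.map(_.reduceLeft(And)).map(c => toCNF(Or(c, right))).reduceLeft(And)
    } else {
      or
    }
  case other => other
}

// ((t1.a > 1 AND t1.b < 2 AND t2.c < 3) OR t2.d > 4) rewrites to
// ((t1.a > 1 AND t1.b < 2) OR t2.d > 4) AND (t2.c < 3 OR t2.d > 4),
// and the second clause references only t2, so it can be pushed to that side.
val cond = Or(
  And(And(Leaf("t1.a > 1", "t1"), Leaf("t1.b < 2", "t1")), Leaf("t2.c < 3", "t2")),
  Leaf("t2.d > 4", "t2"))
println(toCNF(cond))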
@@ -1470,17 +1455,14 @@ object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper {
     case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
 
       val pushDownCandidates =
-        splitConjunctivePredicates(rewriteToCNF(splitCondition(joinCondition, left.outputSet)))
-          .filter(_.deterministic)
+        splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic)
       val (leftEvaluateCondition, rest) =
         pushDownCandidates.partition(_.references.subsetOf(left.outputSet))
       val (rightEvaluateCondition, _) =
         rest.partition(expr => expr.references.subsetOf(right.outputSet))
 
-      val newLeft =
-        maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And).map(restoreExps), left)
-      val newRight =
-        maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And).map(restoreExps), right)
+      val newLeft = maybeWithFilter(leftEvaluateCondition.reduceLeftOption(And), left)
+      val newRight = maybeWithFilter(rightEvaluateCondition.reduceLeftOption(And), right)
 
       joinType match {
         case _: InnerLike | LeftSemi =>
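Dropping restoreExps works because each CNF conjunct now carries its own references, so routing reduces to two partitions over attribute sets. A plain-Scala model of that routing, with illustrative names rather than Catalyst types:

// A predicate and the attributes it references.
case class Pred(sql: String, refs: Set[String])

val leftOutput = Set("x.a", "x.b")
val rightOutput = Set("y.a", "y.b", "y.c")

val pushDownCandidates = Seq(
  Pred("x.a > 3", Set("x.a")),
  Pred("y.c <= 5 OR (y.a > 2 AND y.c < 1)", Set("y.a", "y.c")),
  Pred("x.b = y.b", Set("x.b", "y.b")))

// Same shape as the rule's apply: a conjunct referencing only one side's
// output becomes a Filter on that side; the mixed equi-join conjunct falls
// into neither partition and stays in the join condition.
val (leftEvaluateCondition, rest) =
  pushDownCandidates.partition(_.refs.subsetOf(leftOutput))
val (rightEvaluateCondition, _) =
  rest.partition(_.refs.subsetOf(rightOutput))

assert(leftEvaluateCondition.map(_.sql) == Seq("x.a > 3"))
assert(rightEvaluateCondition.map(_.sql) == Seq("y.c <= 5 OR (y.a > 2 AND y.c < 1)"))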

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala

Lines changed: 29 additions & 0 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, IntegerType}
 import org.apache.spark.unsafe.types.CalendarInterval
 

@@ -1354,4 +1355,32 @@ class FilterPushdownSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_REWRITING_CNF_DEPTH.key}=0") {
+    val x = testRelation.subquery('x)
+    val y = testRelation.subquery('y)
+
+    val originalQuery = {
+      x.join(y, condition = Some(("x.b".attr === "y.b".attr)
+        && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
+        || (("y.a".attr > 2) && ("y.c".attr < 1)))))
+    }
+
+    Seq(0, 10).foreach { depth =>
+      withSQLConf(SQLConf.MAX_REWRITING_CNF_DEPTH.key -> depth.toString) {
+        val optimized = Optimize.execute(originalQuery.analyze)
+        val (left, right) = if (depth == 0) {
+          (testRelation.subquery('x), testRelation.subquery('y))
+        } else {
+          (testRelation.subquery('x),
+            testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
+        }
+        val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
+          && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
+          || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze
+
+        comparePlans(optimized, correctAnswer)
+      }
+    }
+  }
 }
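The test exercises both settings of SQLConf.MAX_REWRITING_CNF_DEPTH: at depth 0, toCNF returns the condition untouched and no extra filter appears; at depth 10, the disjunction yields the pushed-down filter 'c <= 5 || ('a > 2 && 'c < 1) on the y side. For completeness, a hedged sketch of flipping the same knob in a live session, assuming the entry is a runtime-settable conf (the literal key string is not shown in this commit, so it is referenced through SQLConf):

import org.apache.spark.sql.internal.SQLConf

// `spark` is an existing SparkSession, e.g. in spark-shell.
spark.conf.set(SQLConf.MAX_REWRITING_CNF_DEPTH.key, "0")   // disable the CNF rewrite
spark.conf.set(SQLConf.MAX_REWRITING_CNF_DEPTH.key, "10")  // allow up to 10 levels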
