diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e59e3b999aa7..04a89760b9e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -51,7 +51,8 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val blacklistedOnceBatches: Set[String] = Set( "PartitionPruning", - "Extract Python UDFs") + "Extract Python UDFs", + "Push predicate through join by CNF") protected def fixedPoint = FixedPoint( @@ -118,7 +119,11 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Infer Filters", Once, InferFiltersFromConstraints) :: Batch("Operator Optimization after Inferring Filters", fixedPoint, - rulesWithoutInferFiltersFromConstraints: _*) :: Nil + rulesWithoutInferFiltersFromConstraints: _*) :: + // Set strategy to Once to avoid pushing filter every time because we do not change the + // join condition. + Batch("Push predicate through join by CNF", Once, + PushPredicateThroughJoinByCNF) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: @@ -1372,6 +1377,96 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Rewriting join condition to conjunctive normal form expression so that we can push + * more predicate. + */ +object PushPredicateThroughJoinByCNF extends Rule[LogicalPlan] with PredicateHelper { + /** + * Rewrite pattern: + * 1. (a && b) || (c && d) --> (a || c) && (a || d) && (b || c) && (b && d) + * 2. (a && b) || c --> (a || c) && (b || c) + * 3. a || (b && c) --> (a || b) && (a || c) + * + * To avoid generating too many predicates, we first group the columns from the same table. + */ + private def toCNF(condition: Expression, depth: Int = 0): Expression = { + if (depth < SQLConf.get.maxRewritingCNFDepth) { + condition match { + case or @ Or(left: And, right: And) => + val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) + val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) + if (lhs.size > 1) { + lhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1) + }.reduce(And) + } else if (rhs.size > 1) { + rhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } + + case or @ Or(left: And, right) => + val lhs = splitConjunctivePredicates(left).groupBy(_.references.map(_.qualifier)) + if (lhs.size > 1) { + lhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(e, depth + 1), toCNF(right, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } + + case or @ Or(left, right: And) => + val rhs = splitConjunctivePredicates(right).groupBy(_.references.map(_.qualifier)) + if (rhs.size > 1) { + rhs.values.map(_.reduceLeft(And)).map { e => + toCNF(Or(toCNF(left, depth + 1), toCNF(e, depth + 1)), depth + 1) + }.reduce(And) + } else { + or + } + + case And(left, right) => + And(toCNF(left, depth + 1), toCNF(right, depth + 1)) + + case other => + other + } + } else { + condition + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case j @ Join(left, right, joinType, Some(joinCondition), hint) => + val pushDownCandidates = + splitConjunctivePredicates(toCNF(joinCondition)).filter(_.deterministic) + val (leftFilterConditions, rest) = + pushDownCandidates.partition(_.references.subsetOf(left.outputSet)) + val (rightFilterConditions, _) = + rest.partition(expr => expr.references.subsetOf(right.outputSet)) + + val newLeft = leftFilterConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = rightFilterConditions. + reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + + joinType match { + case _: InnerLike | LeftSemi => + Join(newLeft, newRight, joinType, Some(joinCondition), hint) + case RightOuter => + Join(newLeft, right, RightOuter, Some(joinCondition), hint) + case LeftOuter | LeftAnti | ExistenceJoin(_) => + Join(left, newRight, joinType, Some(joinCondition), hint) + case FullOuter => j + case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node") + case UsingJoin(_, _) => sys.error("Untransformed Using join node") + } + } +} + /** * Combines two adjacent [[Limit]] operators into one, merging the * expressions into one single expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c739fa516f0c..d8ac3943d233 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -544,6 +544,18 @@ object SQLConf { .booleanConf .createWithDefault(true) + val MAX_REWRITING_CNF_DEPTH = + buildConf("spark.sql.maxRewritingCNFDepth") + .internal() + .doc("The maximum depth of rewriting a join condition to conjunctive normal form " + + "expression. The deeper, the more predicate may be found, but the optimization time will " + + "increase. The default is 10. By setting this value to 0 this feature can be disabled.") + .version("3.1.0") + .intConf + .checkValue(_ >= 0, + "The depth of the maximum rewriting conjunction normal form must be positive.") + .createWithDefault(10) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -2845,6 +2857,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def maxRewritingCNFDepth: Int = getConf(MAX_REWRITING_CNF_DEPTH) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 70e29dca46e9..d3c338c5789d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -25,12 +25,17 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, IntegerType} import org.apache.spark.unsafe.types.CalendarInterval class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { + + override protected val blacklistedOnceBatches: Set[String] = + Set("Push predicate through join by CNF") + val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: @@ -39,7 +44,9 @@ class FilterPushdownSuite extends PlanTest { PushPredicateThroughNonJoin, BooleanSimplification, PushPredicateThroughJoin, - CollapseProject) :: Nil + CollapseProject) :: + Batch("Push predicate through join by CNF", Once, + PushPredicateThroughJoinByCNF) :: Nil } val attrA = 'a.int @@ -1230,4 +1237,154 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), expected) } + + test("inner join: rewrite filter predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y) + .where(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && Not(("x.a".attr > 3) + && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) + && (("x.a".attr <= 1) || ("y.a".attr <= 11)))) + .analyze + comparePlans(optimized, correctAnswer) + } + + test("left join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val correctAnswer = + left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("right join: rewrite join predicates to conjunctive normal form") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.where('a > 3 || 'a > 1).subquery('x) + val right = testRelation.subquery('y) + val correctAnswer = + left.join(right, joinType = RightOuter, condition = Some("x.b".attr === "y.b".attr + && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + val optimized = Optimize.execute(originalQuery.analyze) + val left = testRelation.subquery('x) + val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y) + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } + + test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_REWRITING_CNF_DEPTH.key}=0") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, condition = Some(("x.b".attr === "y.b".attr) + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))) + } + + Seq(0, 10).foreach { depth => + withSQLConf(SQLConf.MAX_REWRITING_CNF_DEPTH.key -> depth.toString) { + val optimized = Optimize.execute(originalQuery.analyze) + val (left, right) = if (depth == 0) { + (testRelation.subquery('x), testRelation.subquery('y)) + } else { + (testRelation.subquery('x), + testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)) + } + val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr + && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) + || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze + + comparePlans(optimized, correctAnswer) + } + } + } }