diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index b65221c236bf..a2b31278326d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -88,6 +88,22 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  // Extract a list of logical plans to be joined for join-order comparisons.
+  // Since `ExtractFiltersAndInnerJoins` handles left-deep trees only, this function uses
+  // the same strategy to extract the plan list.
+  private[optimizer] def extractLeftDeepInnerJoins(plan: LogicalPlan)
+    : Seq[LogicalPlan] = plan match {
+    case Join(left, right, _: InnerLike, _, hint) if hint == JoinHint.NONE =>
+      right +: extractLeftDeepInnerJoins(left)
+    case Filter(_, child) => extractLeftDeepInnerJoins(child)
+    case Project(_, child) => extractLeftDeepInnerJoins(child)
+    case _ => Seq(plan)
+  }
+
+  private def sameJoinOrder(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+    extractLeftDeepInnerJoins(plan1) == extractLeftDeepInnerJoins(plan2)
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case p @ ExtractFiltersAndInnerJoins(input, conditions)
         if input.size > 2 && conditions.nonEmpty =>
@@ -103,12 +119,18 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
         createOrderedJoin(input, conditions)
       }
 
-      if (p.sameOutput(reordered)) {
-        reordered
+      // To avoid applying this rule repeatedly, we leave the plan unchanged when
+      // `p` and `reordered` have the same join order.
+      if (!sameJoinOrder(reordered, p)) {
+        if (p.sameOutput(reordered)) {
+          reordered
+        } else {
+          // Reordering the joins has changed the order of the columns.
+          // Inject a projection to make sure we restore the expected ordering.
+          Project(p.output, reordered)
+        }
       } else {
-        // Reordering the joins have changed the order of the columns.
-        // Inject a projection to make sure we restore to the expected ordering.
-        Project(p.output, reordered)
+        p
       }
   }
 }
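For intuition, here is a standalone sketch of what `extractLeftDeepInnerJoins` returns. It is not part of the patch: the object name is made up, and the snippet has to live in the `org.apache.spark.sql.catalyst.optimizer` package because the helper is `private[optimizer]`.

// Sketch only: assumes placement in the optimizer package, since
// `extractLeftDeepInnerJoins` is `private[optimizer]`.
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object ExtractLeftDeepInnerJoinsSketch {
  val x = LocalRelation('a.int).subquery('x)
  val y = LocalRelation('b.int).subquery('y)
  val z = LocalRelation('c.int).subquery('z)

  // A left-deep tree ((x JOIN y) JOIN z) with an attribute-only projection in
  // between. The helper recurses through Project/Filter nodes and returns the
  // joined plans, right-most one first.
  val plan = x.join(y, Inner, Some($"x.a" === $"y.b"))
    .select($"x.a")
    .join(z, Inner, Some($"x.a" === $"z.c"))
    .analyze

  // Roughly Seq(z, y, x), after the analyzer eliminates the subquery aliases.
  val flattened = ReorderJoin.extractLeftDeepInnerJoins(plan)
}

Comparing these flattened lists is what lets `sameJoinOrder` detect that a "reordered" plan is in fact unchanged, so the rule can return `p` as-is and avoid re-applying itself.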
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 415ce4678811..e566938f2893 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -222,7 +222,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
 }
 
 /**
- * A pattern that collects the filter and inner joins.
+ * A pattern that collects the filter and inner joins, and skips projections with attributes only.
  *
  *          Filter
  *            |
@@ -230,6 +230,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
  *          /    \            ---->      (Seq(plan0, plan1, plan2), conditions)
  *      Filter   plan2
  *        |
+ *     Project
+ *        |
  *  inner join
  *      /    \
  *   plan0    plan1
@@ -250,22 +252,23 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
       val (plans, conditions) = flattenJoin(left, joinType)
       (plans ++ Seq((right, joinType)),
         conditions ++ cond.toSeq.flatMap(splitConjunctivePredicates))
 
-    case Filter(filterCondition, j @ Join(_, _, _: InnerLike, _, hint)) if hint == JoinHint.NONE =>
-      val (plans, conditions) = flattenJoin(j)
+    case Filter(filterCondition, child) =>
+      val (plans, conditions) = flattenJoin(child)
       (plans, conditions ++ splitConjunctivePredicates(filterCondition))
-
+    case p @ Project(_, child)
+        // Keep flattening joins when the project list contains only attributes
+        if p.projectList.forall(_.isInstanceOf[Attribute]) =>
+      flattenJoin(child)
     case _ => (Seq((plan, parentJoinType)), Seq.empty)
   }
 
-  def unapply(plan: LogicalPlan)
-      : Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]
-      = plan match {
-    case f @ Filter(filterCondition, j @ Join(_, _, joinType: InnerLike, _, hint))
-        if hint == JoinHint.NONE =>
-      Some(flattenJoin(f))
-    case j @ Join(_, _, joinType, _, hint) if hint == JoinHint.NONE =>
-      Some(flattenJoin(j))
-    case _ => None
+  def unapply(plan: LogicalPlan): Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])] = {
+    val (plans, conditions) = flattenJoin(plan)
+    if (plans.size > 1) {
+      Some((plans, conditions))
+    } else {
+      None
+    }
   }
 }
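As an illustration of the extended pattern (a sketch, not part of the patch; the object name is hypothetical), an attribute-only `Project` between two inner joins no longer stops the flattening:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}

object FlattenThroughProjectSketch {
  val x = LocalRelation('a.int, 'b.int).subquery('x)
  val y = LocalRelation('c.int).subquery('y)
  val z = LocalRelation('d.int).subquery('z)

  // An attribute-only Project sits between the two inner joins. The pattern
  // previously stopped flattening at the Project; it now walks through it and
  // collects all three relations and both join predicates.
  val plan: LogicalPlan = x.join(y, Inner, Some($"x.a" === $"y.c"))
    .select($"x.a", $"y.c")
    .join(z, Inner, Some($"x.a" === $"z.d"))
    .analyze

  plan match {
    case ExtractFiltersAndInnerJoins(plans, conditions) =>
      assert(plans.size == 3 && conditions.size == 2)
    case _ =>
      sys.error("unexpected: the pattern no longer stops at attribute-only projections")
  }
}

Note the projection is only skipped when every project-list entry is a bare `Attribute`; a projection computing new expressions still terminates the flattening, since reordering joins underneath it could change its input.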
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 3d81c567eff1..d085d8dee204 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -20,11 +20,13 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Expression}
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner, InnerLike, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan
+import org.apache.spark.sql.internal.SQLConf
 
 class JoinOptimizationSuite extends PlanTest {
 
@@ -47,6 +49,20 @@ class JoinOptimizationSuite extends PlanTest {
   val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
   val testRelation1 = LocalRelation('d.int)
 
+  private def testExtractCheckCross(
+      plan: LogicalPlan,
+      expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]): Unit = {
+    ExtractFiltersAndInnerJoins.unapply(plan) match {
+      case Some((input, conditions)) =>
+        expected.map { case (expectedPlans, expectedConditions) =>
+          assert(expectedPlans === input)
+          assert(expectedConditions.toSet === conditions.toSet)
+        }
+      case None =>
+        assert(expected.isEmpty)
+    }
+  }
+
   test("extract filters and joins") {
     val x = testRelation.subquery('x)
     val y = testRelation1.subquery('y)
@@ -64,12 +80,6 @@ class JoinOptimizationSuite extends PlanTest {
       testExtractCheckCross(plan, expectedNoCross)
     }
 
-    def testExtractCheckCross(plan: LogicalPlan,
-        expected: Option[(Seq[(LogicalPlan, InnerLike)], Seq[Expression])]): Unit = {
-      assert(
-        ExtractFiltersAndInnerJoins.unapply(plan) === expected.map(e => (e._1, e._2)))
-    }
-
     testExtract(x, None)
     testExtract(x.where("x.b".attr === 1), None)
     testExtract(x.join(y), Some((Seq(x, y), Seq())))
@@ -126,4 +136,80 @@ class JoinOptimizationSuite extends PlanTest {
       comparePlans(optimized, queryAnswerPair._2.analyze)
     }
   }
+
+  test("Skip projections when flattening joins") {
+    def checkExtractInnerJoins(plan: LogicalPlan): Unit = {
+      val expectedTables = plan.collectLeaves().map { case p => (p, Inner) }
+      val expectedConditions = plan.collect {
+        case Join(_, _, _, Some(cond), _) => cond
+        case Filter(cond, _) => cond
+      }
+      testExtractCheckCross(plan, Some((expectedTables, expectedConditions)))
+    }
+
+    val x = testRelation.subquery('x)
+    val y = testRelation1.subquery('y)
+    val z = testRelation.subquery('z)
+    var joined = x.join(z, Inner, Some($"x.b" === $"z.b"))
+      .select($"x.a", $"z.a", $"z.c")
+      .join(y, Inner, Some($"y.d" === $"z.a")).analyze
+    checkExtractInnerJoins(joined)
+
+    // test case for filter-over-project
+    joined = x.join(z, Inner, Some($"x.b" === $"z.b"))
+      .select($"x.a", $"z.a", $"z.c")
+      .where($"z.a" === 3)
+      .join(y, Inner, Some($"y.d" === $"z.a")).analyze
+    checkExtractInnerJoins(joined)
+
+    // test case for project-over-filter
+    joined = x.join(z, Inner, Some($"x.b" === $"z.b"))
+      .where($"z.a" === 1)
+      .select($"x.a", $"z.a", $"z.c")
+      .join(y, Inner, Some($"y.d" === $"z.a")).analyze
+    checkExtractInnerJoins(joined)
+  }
+
+  test("Reorder joins with projections") {
+    withSQLConf(
+      SQLConf.STARSCHEMA_DETECTION.key -> "true",
+      SQLConf.CBO_ENABLED.key -> "false") {
+      val r0output = Seq('a.int, 'b.int, 'c.int)
+      val r0colStat = ColumnStat(distinctCount = Some(100000000), nullCount = Some(0))
+      val r0colStats = AttributeMap(r0output.map(_ -> r0colStat))
+      val r0 = StatsTestPlan(r0output, 100000000, r0colStats, identifier = Some("r0")).subquery('r0)
+
+      val r1output = Seq('a.int, 'd.int)
+      val r1colStat = ColumnStat(distinctCount = Some(10), nullCount = Some(0))
+      val r1colStats = AttributeMap(r1output.map(_ -> r1colStat))
+      val r1 = StatsTestPlan(r1output, 10, r1colStats, identifier = Some("r1")).subquery('r1)
+
+      val r2output = Seq('b.int, 'e.int)
+      val r2colStat = ColumnStat(distinctCount = Some(100), nullCount = Some(0))
+      val r2colStats = AttributeMap(r2output.map(_ -> r2colStat))
+      val r2 = StatsTestPlan(r2output, 100, r2colStats, identifier = Some("r2")).subquery('r2)
+
+      val r3output = Seq('c.int, 'f.int)
+      val r3colStat = ColumnStat(distinctCount = Some(1), nullCount = Some(0))
+      val r3colStats = AttributeMap(r3output.map(_ -> r3colStat))
+      val r3 = StatsTestPlan(r3output, 1, r3colStats, identifier = Some("r3")).subquery('r3)
+
+      val joined = r0.join(r1, Inner, Some($"r0.a" === $"r1.a"))
+        .select($"r0.b", $"r0.c", $"r1.d")
+        .where($"r1.d" >= 3)
+        .join(r2, Inner, Some($"r0.b" === $"r2.b"))
+        .where($"r2.e" >= 5)
+        .select($"r0.c", $"r1.d", $"r2.e")
+        .join(r3, Inner, Some($"r0.c" === $"r3.c"))
+        .select($"r1.d", $"r2.e", $"r3.f")
+        .where($"r3.f" <= 100)
+        .analyze
+
+      val optimized = Optimize.execute(joined)
+      val optJoins = ReorderJoin.extractLeftDeepInnerJoins(optimized)
+      val joinOrder = optJoins.flatMap(_.collect { case p: StatsTestPlan => p }.headOption)
+        .flatMap(_.identifier)
+      assert(joinOrder === Seq("r2", "r1", "r3", "r0"))
+    }
+  }
 }
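A note on the asserted order: `extractLeftDeepInnerJoins` lists the right child of each join first, so `Seq("r2", "r1", "r3", "r0")` corresponds to the left-deep tree (((r0 JOIN r3) JOIN r1) JOIN r2). A hypothetical inverse helper (not in the patch; join conditions omitted for brevity) makes the mapping concrete:

import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}

// Rebuilds a left-deep inner-join tree from a flattened plan list (right-most
// plan first), i.e. the inverse of extractLeftDeepInnerJoins modulo the
// skipped Filter/Project nodes; conditions are left as None for brevity.
def toLeftDeepJoin(plans: Seq[LogicalPlan]): LogicalPlan =
  plans.reverse.reduceLeft((left, right) => Join(left, right, Inner, None, JoinHint.NONE))

// toLeftDeepJoin(Seq(r2, r1, r3, r0)) == (((r0 JOIN r3) JOIN r1) JOIN r2)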
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
index 9dceca59f5b8..de743147c9dd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala
@@ -69,7 +69,8 @@ case class StatsTestPlan(
     outputList: Seq[Attribute],
     rowCount: BigInt,
     attributeStats: AttributeMap[ColumnStat],
-    size: Option[BigInt] = None) extends LeafNode {
+    size: Option[BigInt] = None,
+    identifier: Option[String] = None) extends LeafNode {
   override def output: Seq[Attribute] = outputList
   override def computeStats(): Statistics = Statistics(
     // If sizeInBytes is useless in testing, we just use a fake value