diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index 6540e95b01e3f..f92d8f5b8e534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -43,10 +43,10 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { val result = plan transformDown { // Start reordering with a joinable item, which is an InnerLike join with conditions. // Avoid reordering if a join hint is present. - case j @ Join(_, _, _: InnerLike, Some(cond), hint) if hint == JoinHint.NONE => + case j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE) => reorder(j, j.output) - case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), hint)) - if projectList.forall(_.isInstanceOf[Attribute]) && hint == JoinHint.NONE => + case p @ Project(projectList, Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) + if projectList.forall(_.isInstanceOf[Attribute]) => reorder(p, p.output) } // After reordering is finished, convert OrderedJoin back to Join. @@ -77,12 +77,12 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { */ private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { plan match { - case Join(left, right, _: InnerLike, Some(cond), _) => + case Join(left, right, _: InnerLike, Some(cond), JoinHint.NONE) => val (leftPlans, leftConditions) = extractInnerJoins(left) val (rightPlans, rightConditions) = extractInnerJoins(right) (leftPlans ++ rightPlans, splitConjunctivePredicates(cond).toSet ++ leftConditions ++ rightConditions) - case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) + case Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) if projectList.forall(_.isInstanceOf[Attribute]) => extractInnerJoins(j) case _ => @@ -91,11 +91,11 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper { } private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { - case j @ Join(left, right, jt: InnerLike, Some(cond), _) => + case j @ Join(left, right, jt: InnerLike, Some(cond), JoinHint.NONE) => val replacedLeft = replaceWithOrderedJoin(left) val replacedRight = replaceWithOrderedJoin(right) OrderedJoin(replacedLeft, replacedRight, jt, Some(cond)) - case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), _)) => + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, Some(cond), JoinHint.NONE)) => p.copy(child = replaceWithOrderedJoin(j)) case _ => plan diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index f1da0a8e865b0..18516ee7872a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -312,6 +312,14 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(t3, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) assertEqualPlans(originalPlan2, originalPlan2) + + val originalPlan3 = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4).hint("broadcast") + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + assertEqualPlans(originalPlan3, originalPlan3) } test("reorder below and above the hint node") { @@ -342,6 +350,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { .join(t4.hint("broadcast")) assertEqualPlans(originalPlan2, bestPlan2) + + val originalPlan3 = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .hint("broadcast") + .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + val bestPlan3 = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(outputsOf(t1, t2, t3): _*) + .hint("broadcast") + .join(t4, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t5, Inner, Some(nameToAttr("t5.v-1-5") === nameToAttr("t3.v-1-100"))) + + assertEqualPlans(originalPlan3, bestPlan3) } private def assertEqualPlans( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala index 30a3d54fd833f..67f0f1a6fd23d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinHintSuite.scala @@ -102,58 +102,60 @@ class JoinHintSuite extends PlanTest with SharedSQLContext { } test("hints prevent join reorder") { - withTempView("a", "b", "c") { - df1.createOrReplaceTempView("a") - df2.createOrReplaceTempView("b") - df3.createOrReplaceTempView("c") - verifyJoinHint( - sql("select /*+ broadcast(a, c)*/ * from a, b, c " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint( - None, - Some(HintInfo(broadcast = true))) :: - JoinHint( - Some(HintInfo(broadcast = true)), - None):: Nil - ) - verifyJoinHint( - sql("select /*+ broadcast(a, c)*/ * from a, c, b " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint.NONE :: + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + withTempView("a", "b", "c") { + df1.createOrReplaceTempView("a") + df2.createOrReplaceTempView("b") + df3.createOrReplaceTempView("c") + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, b, c " + + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( - Some(HintInfo(broadcast = true)), - Some(HintInfo(broadcast = true))):: Nil - ) - verifyJoinHint( - sql("select /*+ broadcast(b, c)*/ * from a, c, b " + - "where a.a1 = b.b1 and b.b1 = c.c1"), - JoinHint( - None, - Some(HintInfo(broadcast = true))) :: + None, + Some(HintInfo(broadcast = true))) :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(a, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + Some(HintInfo(broadcast = true))) :: Nil + ) + verifyJoinHint( + sql("select /*+ broadcast(b, c)*/ * from a, c, b " + + "where a.a1 = b.b1 and b.b1 = c.c1"), JoinHint( None, - Some(HintInfo(broadcast = true))):: Nil - ) - - verifyJoinHint( - df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") - .join(df3, 'b1 === 'c1 && 'a1 < 10), - JoinHint( - Some(HintInfo(broadcast = true)), - None) :: - JoinHint.NONE:: Nil - ) + Some(HintInfo(broadcast = true))) :: + JoinHint( + None, + Some(HintInfo(broadcast = true))) :: Nil + ) - verifyJoinHint( - df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") - .join(df3, 'b1 === 'c1 && 'a1 < 10) - .join(df, 'b1 === 'id), - JoinHint.NONE :: + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10), JoinHint( Some(HintInfo(broadcast = true)), None) :: - JoinHint.NONE:: Nil - ) + JoinHint.NONE :: Nil + ) + + verifyJoinHint( + df1.join(df2, 'a1 === 'b1 && 'a1 > 5).hint("broadcast") + .join(df3, 'b1 === 'c1 && 'a1 < 10) + .join(df, 'b1 === 'id), + JoinHint.NONE :: + JoinHint( + Some(HintInfo(broadcast = true)), + None) :: + JoinHint.NONE :: Nil + ) + } } }