diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala index adff8b7451a94..01634a9d852c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -344,11 +344,14 @@ object JoinReorderDP extends PredicateHelper with Logging { } def betterThan(other: JoinPlan, conf: SQLConf): Boolean = { - val thisCost = BigDecimal(this.planCost.card) * conf.joinReorderCardWeight + - BigDecimal(this.planCost.size) * (1 - conf.joinReorderCardWeight) - val otherCost = BigDecimal(other.planCost.card) * conf.joinReorderCardWeight + - BigDecimal(other.planCost.size) * (1 - conf.joinReorderCardWeight) - thisCost < otherCost + if (other.planCost.card == 0 || other.planCost.size == 0) { + false + } else { + val relativeRows = BigDecimal(this.planCost.card) / BigDecimal(other.planCost.card) + val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size) + relativeRows * conf.joinReorderCardWeight + + relativeSize * (1 - conf.joinReorderCardWeight) < 1 + } } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala index b733bcd4de65e..38a70f0691dd4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} -import org.apache.spark.sql.catalyst.optimizer.JoinReorderDP.JoinPlan import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -316,18 +315,4 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { case (a1, a2) => a1.semanticEquals(a2) } } - - test("SPARK-33935: betterThan should be consistent") { - val plan1 = JoinPlan(null, null, null, Cost(300, 80)) - val plan2 = JoinPlan(null, null, null, Cost(500, 30)) - - // cost1 = 300*0.7 + 80*0.3 = 234 - // cost2 = 500*0.7 + 30*0.3 = 359 - - assert(!plan1.betterThan(plan1, conf)) - assert(!plan2.betterThan(plan2, conf)) - - assert(plan1.betterThan(plan2, conf)) - assert(!plan2.betterThan(plan1, conf)) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala index 24abd8ed344bf..baae934e1e4fe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala @@ -293,12 +293,12 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas (nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) val expected = - t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) + f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) + .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))) + .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner, + Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))) .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner, Some(nameToAttr("t1_c2") === nameToAttr("t4_c2"))) - .join(f1 - .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk"))) - .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))) .select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*) assertEqualPlans(query, expected)