diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index f92d8f5b8e534..93c608dc71a24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -344,14 +344,11 @@ object JoinReorderDP extends PredicateHelper with Logging {
     }
 
     def betterThan(other: JoinPlan, conf: SQLConf): Boolean = {
-      if (other.planCost.card == 0 || other.planCost.size == 0) {
-        false
-      } else {
-        val relativeRows = BigDecimal(this.planCost.card) / BigDecimal(other.planCost.card)
-        val relativeSize = BigDecimal(this.planCost.size) / BigDecimal(other.planCost.size)
-        relativeRows * conf.joinReorderCardWeight +
-          relativeSize * (1 - conf.joinReorderCardWeight) < 1
-      }
+      val thisCost = BigDecimal(this.planCost.card) * conf.joinReorderCardWeight +
+        BigDecimal(this.planCost.size) * (1 - conf.joinReorderCardWeight)
+      val otherCost = BigDecimal(other.planCost.card) * conf.joinReorderCardWeight +
+        BigDecimal(other.planCost.size) * (1 - conf.joinReorderCardWeight)
+      thisCost < otherCost
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 36db2e2dd0ae2..75fe3dd4062a7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
+import org.apache.spark.sql.catalyst.optimizer.JoinReorderDP.JoinPlan
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -385,4 +386,18 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
   private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
     plans.map(_.output).reduce(_ ++ _)
   }
+
+  test("SPARK-33935: betterThan should be consistent") {
+    val plan1 = JoinPlan(null, null, null, Cost(300, 80))
+    val plan2 = JoinPlan(null, null, null, Cost(500, 30))
+
+    // cost1 = 300*0.7 + 80*0.3 = 234
+    // cost2 = 500*0.7 + 30*0.3 = 359
+
+    assert(!plan1.betterThan(plan1, conf))
+    assert(!plan2.betterThan(plan2, conf))
+
+    assert(plan1.betterThan(plan2, conf))
+    assert(!plan2.betterThan(plan1, conf))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
index 5b8e59ae7cb31..b7cf383732923 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
@@ -293,12 +293,12 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
         (nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
 
     val expected =
-      f1.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
-        .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
-        .join(t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1"))), Inner,
-          Some(nameToAttr("t3_c1") === nameToAttr("t4_c1")))
+      t3.join(t4, Inner, Some(nameToAttr("t3_c1") === nameToAttr("t4_c1")))
         .join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner,
           Some(nameToAttr("t1_c2") === nameToAttr("t4_c2")))
+        .join(f1
+          .join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
+          .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk"))))
         .select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*)
 
     assertEqualPlans(query, expected)
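
For context, the sketch below contrasts the removed ratio-based comparison with the new absolute weighted cost, using the `Cost(300, 80)` / `Cost(500, 30)` pair from the new test and the 0.7 default of `spark.sql.cbo.joinReorder.card.weight`. It is a standalone illustration rather than Spark's own classes; `BetterThanSketch`, `oldBetterThan`, `newBetterThan`, and `weightedCost` are names invented for the example. Under the old formula neither plan is reported as better than the other (both weighted ratio sums come out above 1), so the plan kept by the dynamic-programming search depends on the order in which candidates are compared; the new formula gives each plan an absolute cost and therefore a consistent ordering.

```scala
// Standalone sketch, not Spark's classes: contrasts the removed ratio-based
// betterThan with the new absolute weighted cost. The object/method names and
// the 0.7 weight (default of spark.sql.cbo.joinReorder.card.weight) are
// illustrative assumptions for this example.
object BetterThanSketch {
  final case class Cost(card: BigInt, size: BigInt)

  val cardWeight: BigDecimal = BigDecimal("0.7")
  val sizeWeight: BigDecimal = BigDecimal(1) - cardWeight

  // Removed logic: cost of `a` relative to `b`, compared against the constant 1.
  def oldBetterThan(a: Cost, b: Cost): Boolean = {
    if (b.card == 0 || b.size == 0) {
      false
    } else {
      val relativeRows = BigDecimal(a.card) / BigDecimal(b.card)
      val relativeSize = BigDecimal(a.size) / BigDecimal(b.size)
      relativeRows * cardWeight + relativeSize * sizeWeight < BigDecimal(1)
    }
  }

  // New logic: each plan gets an absolute weighted cost, and costs are compared directly.
  def weightedCost(c: Cost): BigDecimal =
    BigDecimal(c.card) * cardWeight + BigDecimal(c.size) * sizeWeight

  def newBetterThan(a: Cost, b: Cost): Boolean = weightedCost(a) < weightedCost(b)

  def main(args: Array[String]): Unit = {
    val plan1 = Cost(300, 80)  // weighted cost 234
    val plan2 = Cost(500, 30)  // weighted cost 359

    // Old model: 300/500 * 0.7 + 80/30 * 0.3 ≈ 1.22 and 500/300 * 0.7 + 30/80 * 0.3 ≈ 1.28,
    // both >= 1, so neither plan "beats" the other and the DP search keeps whichever
    // candidate it happened to record first.
    println(oldBetterThan(plan1, plan2))  // false
    println(oldBetterThan(plan2, plan1))  // false

    // New model: 234 < 359, so plan1 is consistently preferred in both directions.
    println(newBetterThan(plan1, plan2))  // true
    println(newBetterThan(plan2, plan1))  // false
  }
}
```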