@@ -48,6 +48,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output)
}

// After reordering is finished, convert OrderedJoin back to Join
result transformDown {
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
@@ -175,11 +176,20 @@ object JoinReorderDP extends PredicateHelper with Logging {
assert(topOutputSet == p.outputSet)
// Keep the same order of final output attributes.
p.copy(projectList = output)
case finalPlan if !sameOutput(finalPlan, output) =>
Project(output, finalPlan)
case finalPlan =>
finalPlan
}
}

private def sameOutput(plan: LogicalPlan, expectedOutput: Seq[Attribute]): Boolean = {
val thisOutput = plan.output
thisOutput.length == expectedOutput.length && thisOutput.zip(expectedOutput).forall {
case (a1, a2) => a1.semanticEquals(a2)
}
}
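
For context on the helper above: semanticEquals identifies attributes by exprId rather than by name. A minimal sketch of that behavior, assuming the standard AttributeReference API:

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

// Attributes are identified by exprId: withName keeps the exprId,
// while a fresh AttributeReference mints a new one.
val a = AttributeReference("a", IntegerType)()
val renamed = a.withName("b")                           // same exprId, new name
val lookalike = AttributeReference("a", IntegerType)()  // same name, new exprId

assert(a.semanticEquals(renamed))     // true: identity is the exprId
assert(!a.semanticEquals(lookalike))  // false: names alone don't match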

/** Find all possible plans at the next level, based on existing levels. */
private def searchLevel(
existingLevels: Seq[JoinPlanMap],
@@ -86,9 +86,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case ExtractFiltersAndInnerJoins(input, conditions)
case p @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
@@ -99,6 +99,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
} else {
createOrderedJoin(input, conditions)
}

if (p.sameOutput(reordered)) {
reordered
} else {
// Reordering the joins has changed the order of the columns.
// Inject a projection to make sure we restore to the expected ordering.
Project(p.output, reordered)
Contributor:

Just out of curiosity: do we only need this for the top-level join? I feel it's OK to change the column order for intermediate joins.

Contributor (Author):

That's right, only the top-level join really needs to maintain the appearance. But this is the easiest place to implement the fix (the change is local to the rule where the order could have changed, so this projection is easier to understand than adding it elsewhere), and it doesn't affect the final result, because other optimizer rules will get rid of the extra intermediate projections.

E.g., if we do an extra operation on top of the df:

df.groupBy('a, 'b).agg(first('i), first('j), first('x), first('y))

you're going to see that the extra Project gets optimized away in:

=== Result of Batch Operator Optimization before Inferring Filters ===

Before:

Aggregate [a#65, b#66], [a#65, b#66, first(i#63, false) AS first(i, false)#121, first(j#64, false) AS first(j, false)#122, first(x#61, false) AS first(x, false)#123, first(y#62, false) AS first(y, false)#124]
+- Project [x#61, y#62, i#63, j#64, a#65, b#66]
   +- Join Inner, ((a#65 = x#61) && (b#66 = i#63))
      :- Project [x#61, y#62, i#63, j#64]
      :  +- Join Cross
      :     :- Relation[x#61,y#62] parquet
      :     +- Relation[i#63,j#64] parquet
      +- Relation[a#65,b#66] parquet

After:

Aggregate [a#65, b#66], [a#65, b#66, first(i#63, false) AS first(i, false)#121, first(j#64, false) AS first(j, false)#122, first(x#61, false) AS first(x, false)#123, first(y#62, false) AS first(y, false)#124]
+- Join Cross, (b#66 = i#63)
   :- Join Inner, (a#65 = x#61)
   :  :- Relation[x#61,y#62] parquet
   :  +- Relation[a#65,b#66] parquet
   +- Relation[i#63,j#64] parquet

(a few other rules may also remove the extra Project)
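
The kind of cleanup being referred to can be as simple as the following sketch (an illustration of the idea, not Spark's exact rule): a Project that merely re-lists its child's output in the child's own order adds nothing and can be elided.

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

// Illustration: drop a projection that changes neither the set nor the
// order of its child's output attributes.
object RemoveNoOpProject extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case p @ Project(_, child) if p.output == child.output => child
  }
}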

}
}
}

@@ -102,16 +102,19 @@ class JoinOptimizationSuite extends PlanTest {
x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, condition = Some("x.b".attr === "z.b".attr))
.join(y, condition = Some("y.d".attr === "z.a".attr))
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Cross).join(z, Cross)
.where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, Cross, Some("x.b".attr === "z.b".attr))
.join(y, Cross, Some("y.d".attr === "z.a".attr))
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr),
x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner)
.select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
)
)

@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
@@ -124,7 +124,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
// the original order (t1 J t2) J t3.
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
@@ -139,7 +140,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.select(outputsOf(t1, t2, t3): _*) // this is redundant but we'll take it for now
.join(t4)
.select(outputsOf(t1, t2, t4, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
@@ -202,6 +205,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
.select(outputsOf(t1, t4, t2, t3): _*)

assertEqualPlans(originalPlan, bestPlan)
}
@@ -219,6 +223,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
}
}

test("SPARK-26352: join reordering should not change the order of attributes") {
// This test case does not rely on CBO.
// It's similar to the test case above, but catches a reordering bug that the one above doesn't.
val tab1 = LocalRelation('x.int, 'y.int)
val tab2 = LocalRelation('i.int, 'j.int)
val tab3 = LocalRelation('a.int, 'b.int)
val original =
tab1.join(tab2, Cross)
.join(tab3, Inner, Some('a === 'x && 'b === 'i))
val expected =
tab1.join(tab3, Inner, Some('a === 'x))
.join(tab2, Cross, Some('b === 'i))
.select(outputsOf(tab1, tab2, tab3): _*)

assertEqualPlans(original, expected)
}

test("reorder recursively") {
// Original order:
// Join
@@ -266,8 +287,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
private def assertEqualPlans(
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
val optimized = Optimize.execute(originalPlan.analyze)
val analyzed = originalPlan.analyze
val optimized = Optimize.execute(analyzed)
val expected = groundTruthBestPlan.analyze

assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect
assert(analyzed.sameOutput(optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}
}
@@ -218,6 +218,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
.join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1")))
.select(outputsOf(f1, t1, t2, d1, d2): _*)

assertEqualPlans(query, expected)
}
@@ -256,6 +257,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner,
Some(nameToAttr("d1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1")))
.select(outputsOf(d1, t1, t2, f1, d2, t3): _*)

assertEqualPlans(query, expected)
}
@@ -297,6 +299,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
Some(nameToAttr("t3_c1") === nameToAttr("t4_c1")))
.join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner,
Some(nameToAttr("t1_c2") === nameToAttr("t4_c2")))
.select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*)

assertEqualPlans(query, expected)
}
@@ -347,6 +350,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
Some(nameToAttr("d3_c2") === nameToAttr("t1_c1")))
.join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner,
Some(nameToAttr("d2_c2") === nameToAttr("t5_c1")))
.select(outputsOf(d1, t3, t4, f1, d2, t5, t6, d3, t1, t2): _*)

assertEqualPlans(query, expected)
}
@@ -375,6 +379,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
.select(outputsOf(d1, d2, f1, d3): _*)

assertEqualPlans(query, expected)
}
@@ -400,13 +405,27 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBase {
f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1")))
.join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1")))
.select(outputsOf(t1, f1, t2, t3): _*)

assertEqualPlans(query, expected)
}

private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val optimized = Optimize.execute(plan1.analyze)
val analyzed = plan1.analyze
val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze

assert(equivalentOutput(analyzed, expected)) // if this fails, the expected plan itself is incorrect
assert(equivalentOutput(analyzed, optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}

private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
}
}
@@ -182,6 +182,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d2, f1, d3, s3): _*)

assertEqualPlans(query, expected)
}
@@ -220,6 +221,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
@@ -255,7 +257,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2")))

.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
@@ -292,6 +294,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
.select(outputsOf(d1, f1, d2, s3, d3): _*)

assertEqualPlans(query, expected)
}
@@ -395,6 +398,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, f11, f1, d2, s3): _*)

assertEqualPlans(query, equivQuery)
}
@@ -430,6 +434,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
@@ -465,6 +470,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
@@ -499,6 +505,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2),
Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
@@ -532,6 +539,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}
@@ -565,13 +573,27 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.select(outputsOf(d1, d3, f1, d2, s3): _*)

assertEqualPlans(query, expected)
}

private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val optimized = Optimize.execute(plan1.analyze)
private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
val analyzed = plan1.analyze
val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze

assert(equivalentOutput(analyzed, expected)) // if this fails, the expected plan itself is incorrect
assert(equivalentOutput(analyzed, optimized))

compareJoinOrder(optimized, expected)
}

private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
plans.map(_.output).reduce(_ ++ _)
}

private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
}
}
14 changes: 14 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -895,4 +895,18 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row(0, 0, 0))
}
}

test("SPARK-26352: join reordering should not change the order of columns") {
withTable("tab1", "tab2", "tab3") {
Contributor:

Can we reproduce the bug with a temp view?

Contributor (Author):

That's more prone to being subject to ConvertToLocalRelation, i.e. you couldn't copy-and-paste the test case into a Spark shell and expect the same behavior. Would you still prefer a temp view instead?

Contributor:

Let's keep it as it is.
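
For reference, the temp-view variant asked about above might have looked like this (hypothetical, not part of the patch). As discussed, such plans are more exposed to ConvertToLocalRelation, so the saved-table version reproduces the issue more reliably in a shell:

// Hypothetical temp-view setup; the committed test uses saveAsTable instead.
spark.sql("select 1 as x, 100 as y").createOrReplaceTempView("tab1")
spark.sql("select 42 as i, 200 as j").createOrReplaceTempView("tab2")
spark.sql("select 1 as a, 42 as b").createOrReplaceTempView("tab3")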

spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
spark.sql("select 42 as i, 200 as j").write.saveAsTable("tab2")
spark.sql("select 1 as a, 42 as b").write.saveAsTable("tab3")

val df = spark.sql("""
with tmp as (select * from tab1 cross join tab2)
select * from tmp join tab3 on a = x and b = i
""")
checkAnswer(df, Row(1, 100, 42, 200, 1, 42))
Member:

scala> spark.sql("with tmp as (select * from tab1 cross join tab2) select * from tmp join tab3 on a = x and b = i").show()
+---+---+---+---+---+---+
|  x|  y|  i|  j|  a|  b|
+---+---+---+---+---+---+
|  1|100| 42|200|  1| 42|
+---+---+---+---+---+---+

This still passes without code changes, right?

Member:

Maybe the command-line run works but the test case does not, since we turn off ConvertToLocalRelation in the test environment?

Contributor (Author):

Yeah, before sending this PR I actually ran both test cases that I added, with and without the fix, and verified that without the fix the test fails.

The reason why you "think" it works in your environment is that you used show, which after @maropu's earlier PR introduces a Project at the top, which happens to cure the issue.
If you just do a collect you'll see the problem still exists:

scala> df.show
+---+---+---+---+---+---+
|  x|  y|  i|  j|  a|  b|
+---+---+---+---+---+---+
|  1|100| 42|200|  1| 42|
+---+---+---+---+---+---+

scala> df.collect
res5: Array[org.apache.spark.sql.Row] = Array([1,100,1,42,42,200])
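
The mismatch can also be observed without collecting. A sketch of the pre-fix behavior: the user-facing schema keeps the analyzed column order, while the optimized plan's output does not.

// Pre-fix behavior (sketch): schema order vs. optimized-plan output order.
df.schema.fieldNames                                // Array(x, y, i, j, a, b)
df.queryExecution.optimizedPlan.output.map(_.name)  // x, y, a, b, i, j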

}
}
}