
Commit 1c781a4

imback82 authored and cloud-fan committed
[SPARK-32282][SQL] Improve EnsureRequirements.reorderJoinKeys to handle more scenarios such as PartitioningCollection
### What changes were proposed in this pull request?

This PR proposes to improve `EnsureRequirements.reorderJoinKeys` to handle the following scenarios:

1. If the keys cannot be reordered to match the left-side `HashPartitioning`, consider the right-side `HashPartitioning`.
2. Handle `PartitioningCollection`, which may contain `HashPartitioning`.

### Why are the changes needed?

1. For scenario 1), the current behavior matches either the left-side `HashPartitioning` or the right-side `HashPartitioning`. This means that if both sides are `HashPartitioning`, only the left side is tried. The following will not consider the right-side `HashPartitioning`:

```
val df1 = (0 until 10).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 10).map(i => (i % 7, i % 11)).toDF("i2", "j2")
df1.write.format("parquet").bucketBy(4, "i1", "j1").saveAsTable("t1")
df2.write.format("parquet").bucketBy(4, "i2", "j2").saveAsTable("t2")
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val join = t1.join(t2, t1("i1") === t2("j2") && t1("i1") === t2("i2"))
join.explain

== Physical Plan ==
*(5) SortMergeJoin [i1#26, i1#26], [j2#31, i2#30], Inner
:- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#69]
:     +- *(1) Project [i1#26, j1#27]
:        +- *(1) Filter isnotnull(i1#26)
:           +- *(1) ColumnarToRow
:              +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct<i1:int,j1:int>, SelectedBucketsCount: 4 out of 4
+- *(4) Sort [j2#31 ASC NULLS FIRST, i2#30 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(j2#31, i2#30, 4), true, [id=#79]   <===== This can be removed
      +- *(3) Project [i2#30, j2#31]
         +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30))
            +- *(3) ColumnarToRow
               +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct<i2:int,j2:int>, SelectedBucketsCount: 4 out of 4
```

2. For scenario 2), the current behavior does not handle `PartitioningCollection`:

```
val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 100).map(i => (i % 7, i % 11)).toDF("i2", "j2")
val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")
val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2")) // PartitioningCollection
val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3"))
join2.explain

== Physical Plan ==
*(9) SortMergeJoin [j1#8, i1#7], [j3#30, i3#29], Inner
:- *(6) Sort [j1#8 ASC NULLS FIRST, i1#7 ASC NULLS FIRST], false, 0   <===== This can be removed
:  +- Exchange hashpartitioning(j1#8, i1#7, 5), true, [id=#58]   <===== This can be removed
:     +- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner
:        :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0
:        :  +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#45]
:        :     +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8]
:        :        +- *(1) LocalTableScan [_1#2, _2#3]
:        +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0
:           +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#51]
:              +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19]
:                 +- *(3) LocalTableScan [_1#13, _2#14]
+- *(8) Sort [j3#30 ASC NULLS FIRST, i3#29 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(j3#30, i3#29, 5), true, [id=#64]
      +- *(7) Project [_1#24 AS i3#29, _2#25 AS j3#30]
         +- *(7) LocalTableScan [_1#24, _2#25]
```

### Does this PR introduce _any_ user-facing change?

Yes. In the examples above, the shuffle/sort nodes marked `This can be removed` are now removed:

1. Scenario 1):

```
== Physical Plan ==
*(4) SortMergeJoin [i1#26, i1#26], [i2#30, j2#31], Inner
:- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#67]
:     +- *(1) Project [i1#26, j1#27]
:        +- *(1) Filter isnotnull(i1#26)
:           +- *(1) ColumnarToRow
:              +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct<i1:int,j1:int>, SelectedBucketsCount: 4 out of 4
+- *(3) Sort [i2#30 ASC NULLS FIRST, j2#31 ASC NULLS FIRST], false, 0
   +- *(3) Project [i2#30, j2#31]
      +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30))
         +- *(3) ColumnarToRow
            +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct<i2:int,j2:int>, SelectedBucketsCount: 4 out of 4
```

2. Scenario 2):

```
== Physical Plan ==
*(8) SortMergeJoin [i1#7, j1#8], [i3#29, j3#30], Inner
:- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner
:  :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0
:  :  +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#43]
:  :     +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8]
:  :        +- *(1) LocalTableScan [_1#2, _2#3]
:  +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0
:     +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#49]
:        +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19]
:           +- *(3) LocalTableScan [_1#13, _2#14]
+- *(7) Sort [i3#29 ASC NULLS FIRST, j3#30 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(i3#29, j3#30, 5), true, [id=#58]
      +- *(6) Project [_1#24 AS i3#29, _2#25 AS j3#30]
         +- *(6) LocalTableScan [_1#24, _2#25]
```

### How was this patch tested?

Added tests.

Closes #29074 from imback82/reorder_keys.

Authored-by: Terry Kim <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
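At its core, the change makes key reordering return an `Option`, so a failed match against one side's partitioning can fall back to the other side via `orElse`. Below is a minimal, self-contained sketch of that idea; it is not Spark code, and `reorder`/`reorderJoinKeys` here are hypothetical stand-ins (using plain strings instead of Catalyst expressions) for the position-aware implementation shown in the diff further down.

```scala
// Minimal sketch of the fallback idea: reorder now returns Option, so callers
// can chain alternatives instead of silently keeping the original key order.
object ReorderFallbackSketch {
  // Hypothetical stand-in: succeeds only when `keys` is a permutation of
  // `expected`, returning the keys in the expected order.
  def reorder(keys: Seq[String], expected: Seq[String]): Option[Seq[String]] =
    if (keys.sorted == expected.sorted) Some(expected) else None

  // Old behavior tried only the left side; the new behavior falls back to the
  // right side before giving up and keeping the original order.
  def reorderJoinKeys(
      keys: Seq[String],
      leftExprs: Seq[String],
      rightExprs: Seq[String]): Seq[String] = {
    reorder(keys, leftExprs)
      .orElse(reorder(keys, rightExprs))
      .getOrElse(keys)
  }

  def main(args: Array[String]): Unit = {
    // The keys cannot match the left-side order (it contains "c"), but they
    // do match the right side, so the right-side order wins.
    val reordered = reorderJoinKeys(Seq("a", "b"), Seq("b", "c"), Seq("b", "a"))
    assert(reordered == Seq("b", "a"))
    println(reordered)
  }
}
```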
1 parent bbc887b commit 1c781a4

File tree

2 files changed (+168, -12 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 46 additions & 12 deletions
```diff
@@ -135,9 +135,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftKeys: IndexedSeq[Expression],
       rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
-      currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+      currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = {
     if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
-      return (leftKeys, rightKeys)
+      return None
+    }
+
+    // Check if the current order already satisfies the expected order.
+    if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) {
+      return Some(leftKeys, rightKeys)
     }
 
     // Build a lookup between an expression and the positions its holds in the current key seq.
@@ -164,10 +169,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           rightKeysBuffer += rightKeys(index)
         case _ =>
           // The expression cannot be found, or we have exhausted all indices for that expression.
-          return (leftKeys, rightKeys)
+          return None
       }
     }
-    (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
+    Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
   }
 
   private def reorderJoinKeys(
@@ -176,19 +181,48 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      (leftPartitioning, rightPartitioning) match {
-        case (HashPartitioning(leftExpressions, _), _) =>
-          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
-        case (_, HashPartitioning(rightExpressions, _)) =>
-          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
-        case _ =>
-          (leftKeys, rightKeys)
-      }
+      reorderJoinKeysRecursively(
+        leftKeys,
+        rightKeys,
+        Some(leftPartitioning),
+        Some(rightPartitioning))
+        .getOrElse((leftKeys, rightKeys))
     } else {
       (leftKeys, rightKeys)
     }
   }
 
+  /**
+   * Recursively reorders the join keys based on partitioning. It starts reordering the
+   * join keys to match HashPartitioning on either side, followed by PartitioningCollection.
+   */
+  private def reorderJoinKeysRecursively(
+      leftKeys: Seq[Expression],
+      rightKeys: Seq[Expression],
+      leftPartitioning: Option[Partitioning],
+      rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = {
+    (leftPartitioning, rightPartitioning) match {
+      case (Some(HashPartitioning(leftExpressions, _)), _) =>
+        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+          .orElse(reorderJoinKeysRecursively(
+            leftKeys, rightKeys, None, rightPartitioning))
+      case (_, Some(HashPartitioning(rightExpressions, _))) =>
+        reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+          .orElse(reorderJoinKeysRecursively(
+            leftKeys, rightKeys, leftPartitioning, None))
+      case (Some(PartitioningCollection(partitionings)), _) =>
+        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
+          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning))
+        }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning))
+      case (_, Some(PartitioningCollection(partitionings))) =>
+        partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) =>
+          res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p)))
+        }.orElse(None)
+      case _ =>
+        None
+    }
+  }
+
   /**
    * When the physical operators are created for JOIN, the ordering of join keys is based on order
    * in which the join keys appear in the user query. That might not match with the output
```
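The `PartitioningCollection` cases above fold over the candidate partitionings and keep the first successful reordering. A stand-alone sketch of that fold, with plain integers standing in for join-key expressions (the names here are illustrative, not Spark APIs):

```scala
// Sketch of the foldLeft + orElse pattern: try each candidate partitioning in
// turn and keep the first success. Because Option.orElse takes its argument by
// name, later attempts are not even evaluated once a match is found.
object CollectionFoldSketch {
  // Hypothetical stand-in for a reorder attempt against one candidate.
  def tryReorder(candidate: Seq[Int], keys: Set[Int]): Option[Seq[Int]] =
    if (candidate.toSet == keys) Some(candidate) else None

  def main(args: Array[String]): Unit = {
    val keys = Set(1, 2)
    // Candidate expression orders, as a PartitioningCollection would carry.
    val partitionings = Seq(Seq(3, 1), Seq(2, 1), Seq(1, 2))
    val result = partitionings.foldLeft(Option.empty[Seq[Int]]) { (res, p) =>
      res.orElse(tryReorder(p, keys))
    }
    // The first matching candidate wins, even though a later one also matches.
    assert(result == Some(Seq(2, 1)))
    println(s"first matching candidate: $result")
  }
}
```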
sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala

Lines changed: 122 additions & 0 deletions
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.test.SharedSparkSession

class EnsureRequirementsSuite extends SharedSparkSession {
  private val exprA = Literal(1)
  private val exprB = Literal(2)
  private val exprC = Literal(3)

  test("reorder should handle PartitioningCollection") {
    val plan1 = DummySparkPlan(
      outputPartitioning = PartitioningCollection(Seq(
        HashPartitioning(exprA :: exprB :: Nil, 5),
        HashPartitioning(exprA :: Nil, 5))))
    val plan2 = DummySparkPlan()

    // Test PartitioningCollection on the left side of join.
    val smjExec1 = SortMergeJoinExec(
      exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprA, exprB))
        assert(rightKeys === Seq(exprB, exprA))
      case other => fail(other.toString)
    }

    // Test PartitioningCollection on the right side of join.
    val smjExec2 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
          SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprA, exprB))
      case other => fail(other.toString)
    }

    // Both sides are PartitioningCollection, but the left side cannot be reordered to match,
    // so it should fall back to the right side.
    val smjExec3 = SortMergeJoinExec(
      exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
          SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprC, exprA))
        assert(rightKeys === Seq(exprA, exprB))
      case other => fail(other.toString)
    }
  }

  test("reorder should fallback to the other side partitioning") {
    val plan1 = DummySparkPlan(
      outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5))
    val plan2 = DummySparkPlan(
      outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5))

    // Test fallback to the right side, which has HashPartitioning.
    val smjExec1 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
          SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprB, exprC))
      case other => fail(other.toString)
    }

    // Test fallback to the right side, which has PartitioningCollection.
    val plan3 = DummySparkPlan(
      outputPartitioning = PartitioningCollection(Seq(HashPartitioning(exprB :: exprC :: Nil, 5))))
    val smjExec2 = SortMergeJoinExec(
      exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _),
          SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprA))
        assert(rightKeys === Seq(exprB, exprC))
      case other => fail(other.toString)
    }

    // The right side has HashPartitioning, so it is matched first, but no reordering match is
    // found, and it should fall back to the left side, which has a PartitioningCollection.
    val smjExec3 = SortMergeJoinExec(
      exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1)
    EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match {
      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
          SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _),
          SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) =>
        assert(leftKeys === Seq(exprB, exprC))
        assert(rightKeys === Seq(exprB, exprA))
      case other => fail(other.toString)
    }
  }
}
```
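For an end-to-end check outside the unit tests, scenario 2 from the commit message can be replayed directly. This is a sketch assuming a running `SparkSession` bound to `spark` (e.g. in `spark-shell`):

```scala
// Replay scenario 2 from the commit message. With this patch, explain() should
// show no Exchange/Sort between the outer and inner SortMergeJoin, because the
// inner join's PartitioningCollection is now considered when reordering keys.
import spark.implicits._

val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val df2 = (0 until 100).map(i => (i % 7, i % 11)).toDF("i2", "j2")
val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3")

val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2"))
val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3"))
join2.explain() // expect one Exchange per base input, none above the inner join
```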
