
Commit e0ecb66

imback82 authored and cloud-fan committed
[SPARK-31869][SQL] BroadcastHashJoinExec can utilize the build side for its output partitioning
### What changes were proposed in this pull request?

Currently, `BroadcastHashJoinExec`'s `outputPartitioning` only uses the streamed side's `outputPartitioning`. However, if the join type of `BroadcastHashJoinExec` is an inner-like join, the build side's info (the join keys) can be added to `BroadcastHashJoinExec`'s `outputPartitioning`. For example:

```Scala
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "500")
val t1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1")
val t2 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i2", "j2")
val t3 = (0 until 20).map(i => (i % 7, i % 11)).toDF("i3", "j3")
val t4 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i4", "j4")

// join1 is a sort merge join.
val join1 = t1.join(t2, t1("i1") === t2("i2"))

// join2 is a broadcast join where t3 is broadcasted.
val join2 = join1.join(t3, join1("i1") === t3("i3"))

// Join on the column from the broadcasted side (i3).
val join3 = join2.join(t4, join2("i3") === t4("i4"))

join3.explain
```

You can see that `Exchange hashpartitioning(i3#29, 200)` is introduced because there is no output partitioning info from the build side:

```
== Physical Plan ==
*(6) SortMergeJoin [i3#29], [i4#40], Inner
:- *(4) Sort [i3#29 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(i3#29, 200), true, [id=#55]
:     +- *(3) BroadcastHashJoin [i1#7], [i3#29], Inner, BuildRight
:        :- *(3) SortMergeJoin [i1#7], [i2#18], Inner
:        :  :- *(1) Sort [i1#7 ASC NULLS FIRST], false, 0
:        :  :  +- Exchange hashpartitioning(i1#7, 200), true, [id=#28]
:        :  :     +- LocalTableScan [i1#7, j1#8]
:        :  +- *(2) Sort [i2#18 ASC NULLS FIRST], false, 0
:        :     +- Exchange hashpartitioning(i2#18, 200), true, [id=#29]
:        :        +- LocalTableScan [i2#18, j2#19]
:        +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))), [id=#34]
:           +- LocalTableScan [i3#29, j3#30]
+- *(5) Sort [i4#40 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(i4#40, 200), true, [id=#39]
      +- LocalTableScan [i4#40, j4#41]
```

This PR proposes to introduce output partitioning for the build side of `BroadcastHashJoinExec` if the streamed side has a `HashPartitioning` or a collection of `HashPartitioning`s. A new internal config, `spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit`, limits the number of partitionings a `HashPartitioning` can expand to; it can be set to "0" to disable this feature.

### Why are the changes needed?

To remove unnecessary shuffles.

### Does this PR introduce _any_ user-facing change?

Yes, the shuffle in the above example can now be eliminated:

```
== Physical Plan ==
*(5) SortMergeJoin [i3#108], [i4#119], Inner
:- *(3) Sort [i3#108 ASC NULLS FIRST], false, 0
:  +- *(3) BroadcastHashJoin [i1#86], [i3#108], Inner, BuildRight
:     :- *(3) SortMergeJoin [i1#86], [i2#97], Inner
:     :  :- *(1) Sort [i1#86 ASC NULLS FIRST], false, 0
:     :  :  +- Exchange hashpartitioning(i1#86, 200), true, [id=#120]
:     :  :     +- LocalTableScan [i1#86, j1#87]
:     :  +- *(2) Sort [i2#97 ASC NULLS FIRST], false, 0
:     :     +- Exchange hashpartitioning(i2#97, 200), true, [id=#121]
:     :        +- LocalTableScan [i2#97, j2#98]
:     +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, false] as bigint))), [id=#126]
:        +- LocalTableScan [i3#108, j3#109]
+- *(4) Sort [i4#119 ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(i4#119, 200), true, [id=#130]
      +- LocalTableScan [i4#119, j4#120]
```

### How was this patch tested?

Added new tests.

Closes #28676 from imback82/broadcast_join_output.
Authored-by: Terry Kim <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
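As a usage note: a minimal sketch of how this knob could be exercised from user code, reusing the `spark` session from the example above. Only the config key itself comes from this commit; the values shown are illustrative, and the config is internal, so it is not part of the documented public surface:

```Scala
// Disable build-side output partitioning expansion entirely
// (restores the pre-PR behavior):
spark.conf.set("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "0")

// Or allow a HashPartitioning to expand into up to 16 partitionings
// instead of the default 8:
spark.conf.set("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit", "16")
```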
1 parent d0c83f3 commit e0ecb66

6 files changed: +322, -16 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions

```diff
@@ -2658,6 +2658,17 @@ object SQLConf {
       .checkValue(_ > 0, "The difference must be positive.")
       .createWithDefault(4)
 
+  val BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT =
+    buildConf("spark.sql.execution.broadcastHashJoin.outputPartitioningExpandLimit")
+      .internal()
+      .doc("The maximum number of partitionings that a HashPartitioning can be expanded to. " +
+        "This configuration is applicable only for BroadcastHashJoin inner joins and can be " +
+        "set to '0' to disable this feature.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ >= 0, "The value must be non-negative.")
+      .createWithDefault(8)
+
   /**
    * Holds information about keys that have been deprecated.
    *
@@ -2966,6 +2977,9 @@ class SQLConf extends Serializable with Logging {
     LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY))
   }
 
+  def broadcastHashJoinOutputPartitioningExpandLimit: Int =
+    getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
```
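A quick sanity check on the default of 8 (editor's arithmetic, not from the commit): assuming each streamed key maps to exactly one build key, every join key in a `HashPartitioning` either stays as-is or is swapped for its build-side counterpart, so k substitutable keys expand into 2^k partitionings before the limit applies:

```Scala
// Back-of-the-envelope check under the one-to-one mapping assumption:
// each substitutable key doubles the number of expanded partitionings.
def expansions(substitutableKeys: Int): Int = 1 << substitutableKeys

assert(expansions(3) == 8)  // the default limit of 8 covers joins on up to 3 keys
assert(expansions(4) == 16) // a 4-key join would be truncated at the default
```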

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala

Lines changed: 74 additions & 5 deletions

```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import scala.collection.mutable
+
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -26,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioning, Partitioning, PartitioningCollection, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.{BooleanType, LongType}
@@ -51,7 +53,7 @@ case class BroadcastHashJoinExec(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
 
   override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(buildKeys)
+    val mode = HashedRelationBroadcastMode(buildBoundKeys)
     buildSide match {
       case BuildLeft =>
         BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
@@ -60,6 +62,73 @@
     }
   }
 
+  override lazy val outputPartitioning: Partitioning = {
+    joinType match {
+      case _: InnerLike if sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
+        streamedPlan.outputPartitioning match {
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case c: PartitioningCollection => expandOutputPartitioning(c)
+          case other => other
+        }
+      case _ => streamedPlan.outputPartitioning
+    }
+  }
+
+  // A one-to-many mapping from a streamed key to build keys.
+  private lazy val streamedKeyToBuildKeyMapping = {
+    val mapping = mutable.Map.empty[Expression, Seq[Expression]]
+    streamedKeys.zip(buildKeys).foreach {
+      case (streamedKey, buildKey) =>
+        val key = streamedKey.canonicalized
+        mapping.get(key) match {
+          case Some(v) => mapping.put(key, v :+ buildKey)
+          case None => mapping.put(key, Seq(buildKey))
+        }
+    }
+    mapping.toMap
+  }
+
+  // Expands the given partitioning collection recursively.
+  private def expandOutputPartitioning(
+      partitioning: PartitioningCollection): PartitioningCollection = {
+    PartitioningCollection(partitioning.partitionings.flatMap {
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
+      case other => Seq(other)
+    })
+  }
+
+  // Expands the given hash partitioning by substituting streamed keys with build keys.
+  // For example, if the expressions for the given partitioning are Seq("a", "b", "c")
+  // where the streamed keys are Seq("b", "c") and the build keys are Seq("x", "y"),
+  // the expanded partitioning will have the following expressions:
+  // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
+  // The expanded expressions are returned as a PartitioningCollection.
+  private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
+    val maxNumCombinations = sqlContext.conf.broadcastHashJoinOutputPartitioningExpandLimit
+    var currentNumCombinations = 0
+
+    def generateExprCombinations(
+        current: Seq[Expression],
+        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
+      if (currentNumCombinations >= maxNumCombinations) {
+        Nil
+      } else if (current.isEmpty) {
+        currentNumCombinations += 1
+        Seq(accumulated)
+      } else {
+        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
+        generateExprCombinations(current.tail, accumulated :+ current.head) ++
+          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
+            .getOrElse(Nil)
+      }
+    }
+
+    PartitioningCollection(
+      generateExprCombinations(partitioning.expressions, Nil)
+        .map(HashPartitioning(_, partitioning.numPartitions)))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
@@ -135,13 +204,13 @@
       ctx: CodegenContext,
       input: Seq[ExprCode]): (ExprCode, String) = {
     ctx.currentVars = input
-    if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+    if (streamedBoundKeys.length == 1 && streamedBoundKeys.head.dataType == LongType) {
       // generate the join key as Long
-      val ev = streamedKeys.head.genCode(ctx)
+      val ev = streamedBoundKeys.head.genCode(ctx)
       (ev, ev.isNull)
     } else {
       // generate the join key as UnsafeRow
-      val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+      val ev = GenerateUnsafeProjection.createCode(ctx, streamedBoundKeys)
       (ev, s"${ev.value}.anyNull()")
     }
   }
```
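To make the expansion concrete, here is a self-contained sketch of the same combination generation over plain strings standing in for canonicalized expressions. The names `streamedToBuild` and `combos` are illustrative, not Spark API; the sketch reproduces the `Seq("a", "b", "c")` example from the code comment above:

```Scala
// Strings stand in for canonicalized expressions; "b" -> "x" and "c" -> "y"
// mirror the streamed-key -> build-key mapping in the documented example.
val streamedToBuild: Map[String, Seq[String]] = Map("b" -> Seq("x"), "c" -> Seq("y"))
val maxNumCombinations = 8
var currentNumCombinations = 0

// Mirrors generateExprCombinations: keep the streamed key, and additionally
// branch into every build key it maps to, until the limit is reached.
def combos(current: Seq[String], accumulated: Seq[String]): Seq[Seq[String]] = {
  if (currentNumCombinations >= maxNumCombinations) {
    Nil
  } else if (current.isEmpty) {
    currentNumCombinations += 1
    Seq(accumulated)
  } else {
    combos(current.tail, accumulated :+ current.head) ++
      streamedToBuild.get(current.head)
        .map(_.flatMap(b => combos(current.tail, accumulated :+ b)))
        .getOrElse(Nil)
  }
}

// Expands to (a,b,c), (a,b,y), (a,x,c), (a,x,y), as documented above.
assert(combos(Seq("a", "b", "c"), Nil) ==
  Seq(Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y")))
```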

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala

Lines changed: 15 additions & 6 deletions

The join keys are now kept unbound in `buildKeys`/`streamedKeys` so they can be compared against other plan expressions (as in `outputPartitioning` above), while the new `buildBoundKeys`/`streamedBoundKeys` are bound to the child output for execution and codegen.

```diff
@@ -62,21 +62,30 @@ trait HashJoin extends BaseJoinExec {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
-    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
-      case BuildLeft => (lkeys, rkeys)
-      case BuildRight => (rkeys, lkeys)
+      case BuildLeft => (leftKeys, rightKeys)
+      case BuildRight => (rightKeys, leftKeys)
     }
   }
 
+  @transient private lazy val (buildOutput, streamedOutput) = {
+    buildSide match {
+      case BuildLeft => (left.output, right.output)
+      case BuildRight => (right.output, left.output)
+    }
+  }
+
+  @transient protected lazy val buildBoundKeys =
+    bindReferences(HashJoin.rewriteKeyExpr(buildKeys), buildOutput)
+
+  @transient protected lazy val streamedBoundKeys =
+    bindReferences(HashJoin.rewriteKeyExpr(streamedKeys), streamedOutput)
+
   protected def buildSideKeyGenerator(): Projection =
-    UnsafeProjection.create(buildKeys)
+    UnsafeProjection.create(buildBoundKeys)
 
   protected def streamSideKeyGenerator(): UnsafeProjection =
-    UnsafeProjection.create(streamedKeys)
+    UnsafeProjection.create(streamedBoundKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     Predicate.create(condition.get, streamedPlan.output ++ buildPlan.output).eval _
```

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -55,7 +55,8 @@ case class ShuffledHashJoinExec(
     val buildTime = longMetric("buildTime")
     val start = System.nanoTime()
     val context = TaskContext.get()
-    val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager())
+    val relation = HashedRelation(
+      iter, buildBoundKeys, taskMemoryManager = context.taskMemoryManager())
     buildTime += NANOSECONDS.toMillis(System.nanoTime() - start)
     buildDataSize += relation.estimatedSize
     // This relation is usually used until the end of task.
```

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -557,7 +557,8 @@ class AdaptiveQueryExecSuite
 
     withSQLConf(
       SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
-      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") {
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80",
+      SQLConf.BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT.key -> "0") {
       val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
         "SELECT * FROM testData " +
           "join testData2 t2 ON key = t2.a " +
```
