@@ -23,26 +23,37 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
2323import org .apache .spark .sql .catalyst .plans .{LeftAnti , LeftOuter , RightOuter }
2424import org .apache .spark .sql .catalyst .plans .logical .{HintInfo , Join , JoinStrategyHint , LogicalPlan , NO_BROADCAST_HASH , PREFER_SHUFFLE_HASH , SHUFFLE_HASH }
2525import org .apache .spark .sql .catalyst .rules .Rule
26- import org .apache .spark .sql .catalyst .trees .TreeNodeTag
27- import org .apache .spark .sql .execution .{CoalescedPartitionSpec , SparkPlan }
26+ import org .apache .spark .sql .execution .CoalescedPartitionSpec
2827import org .apache .spark .sql .execution .adaptive .OptimizeSkewedJoin .getSkewThreshold
2928import org .apache .spark .sql .internal .SQLConf
3029import org .apache .spark .util .Utils
3130
3231/**
3332 * This optimization rule includes three join selection:
34- * 1. detects a join child that has a high ratio of empty partitions and adds a
33+ * 1. Do not add any until all the children are materialized and don't need additional shuffle.
34+ * 1.1 Won't select any join strategy for the following query as it need additional shuffle
35+ * after all ShuffleQueryStageExec are materialized
36+ * SortMergeJoin
37+ * :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
38+ * +- SortMergeJoin
39+ * :- ShuffleQueryStageExec (hashpartitioning(ID#2, 500))
40+ * +- ShuffleQueryStageExec (hashpartitioning(ID#3, 500))
41+ *
42+ * 1.2 Won't select any join strategy if the other side contains bucket table.
43+ * SortMergeJoin
44+ * :- ShuffleQueryStageExec (hashpartitioning(ID#1, 10000))
45+ * +- BucketTableScan
46+ *
47+ * 2. detects a join child that has a high ratio of empty partitions and adds a
3548 * NO_BROADCAST_HASH hint to avoid it being broadcast, as shuffle join is faster in this case:
3649 * many tasks complete immediately since one join side is empty.
37- * 2 . detects a join child that every partition size is less than local map threshold and adds a
50+ * 3 . detects a join child that every partition size is less than local map threshold and adds a
3851 * PREFER_SHUFFLE_HASH hint to encourage being shuffle hash join instead of sort merge join.
39- * 3 . if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
52+ * 4 . if a join satisfies both NO_BROADCAST_HASH and PREFER_SHUFFLE_HASH,
4053 * then add a SHUFFLE_HASH hint.
4154 */
4255object DynamicJoinSelection extends Rule [LogicalPlan ] with JoinSelectionHelper {
4356
44- val USER_DEFINED_HINT_TAG = TreeNodeTag [Boolean ](" USER_DEFINED_HINT" )
45-
4657 private def hasManyEmptyPartitions (mapStats : MapOutputStatistics ): Boolean = {
4758 val partitionCnt = mapStats.bytesByPartitionId.length
4859 val nonZeroCnt = mapStats.bytesByPartitionId.count(_ > 0 )
@@ -65,22 +76,20 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
6576 streamedStats : Seq [MapOutputStatistics ]): Boolean = {
6677 val maxShuffledHashJoinLocalMapThreshold =
6778 conf.getConf(SQLConf .ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD )
79+ val advisoryStreamPartitionSize =
80+ conf.getConf(SQLConf .ADAPTIVE_SHUFFLE_HASH_JOIN_ADVISORY_STREAM_PARTITION_SIZE )
6881 // If the join is skew, since CARMEL will not handle SHJ skew join, and we are not sure SHJ
6982 // will be faster better SMJ for the left skew join patten, so do not convert to SHJ if any
7083 // join side is skew.
71- if (maxShuffledHashJoinLocalMapThreshold <= 0 || streamedStats.exists(isSkew(_))) {
84+ if (maxShuffledHashJoinLocalMapThreshold <= 0 || advisoryStreamPartitionSize <= 0 ||
85+ streamedStats.exists(isSkew(_))) {
7286 return false
7387 }
74- val partitionSpecs = ShufflePartitionsUtil .coalescePartitions(
75- Array (mapStats) ++ streamedStats,
76- advisoryTargetSize = conf.getConf(SQLConf .ADVISORY_PARTITION_SIZE_IN_BYTES ),
77- minNumPartitions = 0 )
78- partitionSpecs.nonEmpty &&
79- partitionSpecs.forall(_.isInstanceOf [CoalescedPartitionSpec ]) &&
80- partitionSpecs.collect {
81- case CoalescedPartitionSpec (startReducerIndex, endReducerIndex) =>
82- mapStats.bytesByPartitionId.slice(startReducerIndex, endReducerIndex).sum
83- }.forall(_ <= maxShuffledHashJoinLocalMapThreshold)
88+
89+ mapStats.bytesByPartitionId.forall(_ <= maxShuffledHashJoinLocalMapThreshold) &&
90+ streamedStats.filter(_.bytesByPartitionId.length > 0 ).exists { stats =>
91+ Utils .median(stats.bytesByPartitionId, false ) > advisoryStreamPartitionSize
92+ }
8493 }
8594
8695 private def selectJoinStrategy (
@@ -123,16 +132,29 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
123132 stage.computeStats().exists(_.rowCount.exists(_.toLong >= conf.broadcastMaxRowNum))
124133 val adjustDemoteBroadcastHash = rowNumberExceeded || demoteBroadcastHash
125134
126- def collectShuffleStats (plan : LogicalPlan ): Seq [MapOutputStatistics ] = plan match {
135+ var bucketedPlan = false
136+ def collectShuffleStats (plan : LogicalPlan ): Seq [Option [MapOutputStatistics ]] = plan match {
127137 case LogicalQueryStage (_, streamedStage : ShuffleQueryStageExec )
128138 if streamedStage.isMaterialized && streamedStage.mapStats.isDefined =>
129- Seq (streamedStage.mapStats.get)
130- case _ => plan.children.flatMap(collectShuffleStats)
139+ Seq (streamedStage.mapStats)
140+ case LogicalQueryStage (_, _ : ShuffleQueryStageExec ) => Seq (None )
141+ case _ if plan.children.nonEmpty => plan.children.flatMap(collectShuffleStats)
142+ case _ =>
143+ bucketedPlan = true
144+ Seq ()
131145 }
132- val preferShuffleHash =
133- preferShuffledHashJoin(stage.mapStats.get, collectShuffleStats(streamedPlan))
134146
135- logInfo(s " canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
147+ val streamedStats = collectShuffleStats(streamedPlan)
148+ val allStats = Array (stage.mapStats) ++ streamedStats
149+
150+ val shuffleMaterialized =
151+ allStats.forall(_.isDefined) &&
152+ allStats.map(_.get.bytesByPartitionId.length).distinct.length == 1
153+ val preferShuffleHash = ! bucketedPlan && shuffleMaterialized &&
154+ preferShuffledHashJoin(stage.mapStats.get, streamedStats.map(_.get))
155+
156+ logInfo(s " isLeft = $isLeft, shuffleMaterialized = $shuffleMaterialized, " +
157+ s " canBroadcastPlan = $canBroadcastPlan, rowNumberExceeded = " +
136158 s " $rowNumberExceeded, adjustDemoteBroadcastHash = $adjustDemoteBroadcastHash, " +
137159 s " preferShuffleHash = $preferShuffleHash" )
138160 if (adjustDemoteBroadcastHash && preferShuffleHash) {
@@ -150,24 +172,18 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
150172 }
151173
152174 def apply (plan : LogicalPlan ): LogicalPlan = plan.transformDown {
153- case j @ ExtractEquiJoinKeys (_, _, _, _, left, right, hint) =>
154- if (left.getTagValue(USER_DEFINED_HINT_TAG ).isEmpty) {
155- left.setTagValue(USER_DEFINED_HINT_TAG , hint.leftHint.exists(_.strategy.isDefined))
156- }
157- if (right.getTagValue(USER_DEFINED_HINT_TAG ).isEmpty) {
158- right.setTagValue(USER_DEFINED_HINT_TAG , hint.rightHint.exists(_.strategy.isDefined))
159- }
175+ case j @ ExtractEquiJoinKeys (_, _, _, _, _, _, hint) =>
160176 var newHint = hint
161- if (! left.getTagValue(USER_DEFINED_HINT_TAG ).getOrElse(false )) {
177+ if (! hint.leftHint.exists(_.strategy.isDefined) ||
178+ hint.leftHint.get.strategy.contains(NO_BROADCAST_HASH )) {
162179 selectJoinStrategy(j, true ).foreach { strategy =>
163- logInfo(s " Set left side join strategy: $strategy" )
164180 newHint = newHint.copy(leftHint =
165181 Some (hint.leftHint.getOrElse(HintInfo ()).copy(strategy = Some (strategy))))
166182 }
167183 }
168- if (! right.getTagValue(USER_DEFINED_HINT_TAG ).getOrElse(false )) {
184+ if (! hint.rightHint.exists(_.strategy.isDefined) ||
185+ hint.rightHint.get.strategy.contains(NO_BROADCAST_HASH )) {
169186 selectJoinStrategy(j, false ).foreach { strategy =>
170- logInfo(s " Set right side join strategy: $strategy" )
171187 newHint = newHint.copy(rightHint =
172188 Some (hint.rightHint.getOrElse(HintInfo ()).copy(strategy = Some (strategy))))
173189 }
0 commit comments