@@ -144,6 +144,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
144144 sizes.sum / sizes.length
145145 }
146146
/**
 * Looks for shuffle stage information inside `plan`.
 *
 * The tree is searched with `collectFirst`, and the first node accepted by the
 * `ShuffleStage` extractor decides the result.
 *
 * @param plan root of the subtree to search
 * @return the extracted [[ShuffleStageInfo]], or `None` when no node matches
 */
private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] =
  plan.collectFirst { case ShuffleStage(stageInfo) => stageInfo }
153+
/**
 * Substitutes `newCtm` for the matching shuffle reader(s) inside `smj`.
 *
 * The tree is rewritten bottom-up; a `CustomShuffleReaderExec` node is swapped for `newCtm`
 * when its child yields the same result as `newCtm.child`. Every other node is kept as is.
 *
 * NOTE(review): only nodes that already are a `CustomShuffleReaderExec` can match, so a
 * shuffle stage that is not wrapped in such a reader is left untouched here -- confirm that
 * callers do not rely on it being replaced.
 *
 * @param smj    plan to rewrite
 * @param newCtm replacement reader
 * @return the rewritten plan
 */
private def replaceSkewedShufleReader(
    smj: SparkPlan,
    newCtm: CustomShuffleReaderExec): SparkPlan = {
  smj.transformUp {
    case CustomShuffleReaderExec(readerChild, _) if readerChild.sameResult(newCtm.child) =>
      newCtm
  }
}
161+
147162 /*
148163 * This method aims to optimize the skewed join with the following steps:
149164 * 1. Check whether the shuffle partition is skewed based on the median size
@@ -158,95 +173,107 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
158173 */
// Rewrites, bottom-up, every supported sort-merge join whose two sorted inputs each contain a
// shuffle stage. Joins for which no skewed partition gets split are returned unchanged.
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
  case smj @ SortMergeJoinExec(_, _, joinType, _, s1: SortExec, s2: SortExec, _)
      if supportedJoinTypes.contains(joinType) =>
    // The shuffle stage does not have to be the direct child of the sort, so it is looked up
    // in each sorted subtree.
    (findShuffleStage(s1), findShuffleStage(s2)) match {
      case (Some(left), Some(right)) =>
        assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
        val numPartitions = left.partitionsWithSizes.length
        // We use the median size of the original shuffle partitions to detect skewed partitions.
        val leftMedSize = medianSize(left.mapStats)
        val rightMedSize = medianSize(right.mapStats)
        logDebug(
          s"""
            |Optimizing skewed join.
            |Left side partitions size info:
            |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
            |Right side partitions size info:
            |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
          """.stripMargin)
        val canSplitLeft = canSplitLeftSide(joinType)
        val canSplitRight = canSplitRightSide(joinType)
        // We use the actual partition sizes (may be coalesced) to calculate target size, so that
        // the final data distribution is even (coalesced partitions + split partitions).
        val leftActualSizes = left.partitionsWithSizes.map(_._2)
        val rightActualSizes = right.partitionsWithSizes.map(_._2)
        val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
        val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

        val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
        val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
        var numSkewedLeft = 0
        var numSkewedRight = 0
        for (partitionIndex <- 0 until numPartitions) {
          val leftActualSize = leftActualSizes(partitionIndex)
          val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
          val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
          // A spec spanning more than one reducer index is treated as coalesced.
          val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

          val rightActualSize = rightActualSizes(partitionIndex)
          val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
          val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
          val isRightCoalesced =
            rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

          // A skewed partition should never be coalesced, but skip it here just to be safe.
          val leftParts = if (isLeftSkew && !isLeftCoalesced) {
            val reducerId = leftPartSpec.startReducerIndex
            val skewSpecs = createSkewPartitionSpecs(
              left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
            if (skewSpecs.isDefined) {
              logDebug(s"Left side partition $partitionIndex " +
                s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
                s"split it into ${skewSpecs.get.length} parts.")
              numSkewedLeft += 1
            }
            // Fall back to the unsplit spec when no split specs were produced.
            skewSpecs.getOrElse(Seq(leftPartSpec))
          } else {
            Seq(leftPartSpec)
          }

          // A skewed partition should never be coalesced, but skip it here just to be safe.
          val rightParts = if (isRightSkew && !isRightCoalesced) {
            val reducerId = rightPartSpec.startReducerIndex
            val skewSpecs = createSkewPartitionSpecs(
              right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
            if (skewSpecs.isDefined) {
              logDebug(s"Right side partition $partitionIndex " +
                s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
                s"split it into ${skewSpecs.get.length} parts.")
              numSkewedRight += 1
            }
            // Fall back to the unsplit spec when no split specs were produced.
            skewSpecs.getOrElse(Seq(rightPartSpec))
          } else {
            Seq(rightPartSpec)
          }

          // Emit the cross product of the two sides' specs so that every left piece is paired
          // with every right piece of the same partition index.
          for {
            leftSidePartition <- leftParts
            rightSidePartition <- rightParts
          } {
            leftSidePartitions += leftSidePartition
            rightSidePartitions += rightSidePartition
          }
        }

        logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
        if (numSkewedLeft > 0 || numSkewedRight > 0) {
          val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
          val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
          // Swap the readers in place inside the join's subtree, then flag the join as skewed.
          replaceSkewedShufleReader(replaceSkewedShufleReader(smj, newLeft), newRight) match {
            case newSmj: SortMergeJoinExec => newSmj.copy(isSkewJoin = true)
            // Not expected, since only reader nodes are replaced; keep the original join
            // rather than failing with a cast error.
            case _ => smj
          }
        } else {
          smj
        }

      case _ =>
        // At least one side has no shuffle stage to work with: leave the join as it is.
        smj
    }
}
252279
@@ -263,15 +290,19 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
263290 val shuffleStages = collectShuffleStages(plan)
264291
265292 if (shuffleStages.length == 2 ) {
266- // When multi table join, there will be too many complex combination to consider.
267- // Currently we only handle 2 table join like following use case .
293+ // SPARK-32201. Skew join supports the pattern below. ".." may contain any number of nodes,
294+ // including BroadcastHashJoinExec, so joins of more than two tables can be handled.
268295 // SMJ
269296 // Sort
270- // Shuffle
297+ // ..
298+ // Shuffle
271299 // Sort
272- // Shuffle
300+ // ..
301+ // Shuffle
273302 val optimizePlan = optimizeSkewJoin(plan)
274- val numShuffles = ensureRequirements.apply(optimizePlan).collect {
303+ val ensuredPlan = ensureRequirements.apply(optimizePlan)
304+ println(ensuredPlan)
305+ val numShuffles = ensuredPlan.collect {
275306 case e : ShuffleExchangeExec => e
276307 }.length
277308
@@ -316,6 +347,23 @@ private object ShuffleStage {
316347 }
317348 Some (ShuffleStageInfo (s, mapStats, partitions))
318349
350+ case _ : LeafExecNode => None
351+
352+ case _ @ UnaryExecNode ((_, ShuffleStage (ss : ShuffleStageInfo ))) =>
353+ Some (ss)
354+
355+ case b : BinaryExecNode =>
356+ b.left match {
357+ case _ @ ShuffleStage (ss : ShuffleStageInfo ) =>
358+ Some (ss)
359+ case _ =>
360+ b.right match {
361+ case _ @ ShuffleStage (ss : ShuffleStageInfo ) =>
362+ Some (ss)
363+ case _ => None
364+ }
365+ }
366+
319367 case _ => None
320368 }
321369}
0 commit comments