Skip to content

Commit 479d56b

Browse files
committed
[SPARK-32201][SQL] More general skew join pattern matching
1 parent 5d296ed commit 479d56b

File tree

2 files changed

+196
-104
lines changed

2 files changed

+196
-104
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala

Lines changed: 135 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
144144
sizes.sum / sizes.length
145145
}
146146

147+
private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
148+
plan collectFirst {
149+
case _ @ ShuffleStage(shuffleStageInfo) =>
150+
shuffleStageInfo
151+
}
152+
}
153+
154+
private def replaceSkewedShufleReader(
155+
smj: SparkPlan, newCtm: CustomShuffleReaderExec): SparkPlan = {
156+
smj transformUp {
157+
case _ @ CustomShuffleReaderExec(child, _) if child.sameResult(newCtm.child) =>
158+
newCtm
159+
}
160+
}
161+
147162
/*
148163
* This method aim to optimize the skewed join with the following steps:
149164
* 1. Check whether the shuffle partition is skewed based on the median size
@@ -158,95 +173,107 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
158173
*/
159174
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
160175
case smj @ SortMergeJoinExec(_, _, joinType, _,
161-
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
162-
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
176+
s1 @ SortExec(_, _, _, _),
177+
s2 @ SortExec(_, _, _, _), _)
163178
if supportedJoinTypes.contains(joinType) =>
164-
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
165-
val numPartitions = left.partitionsWithSizes.length
166-
// We use the median size of the original shuffle partitions to detect skewed partitions.
167-
val leftMedSize = medianSize(left.mapStats)
168-
val rightMedSize = medianSize(right.mapStats)
169-
logDebug(
170-
s"""
171-
|Optimizing skewed join.
172-
|Left side partitions size info:
173-
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
174-
|Right side partitions size info:
175-
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
176-
""".stripMargin)
177-
val canSplitLeft = canSplitLeftSide(joinType)
178-
val canSplitRight = canSplitRightSide(joinType)
179-
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
180-
// the final data distribution is even (coalesced partitions + split partitions).
181-
val leftActualSizes = left.partitionsWithSizes.map(_._2)
182-
val rightActualSizes = right.partitionsWithSizes.map(_._2)
183-
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
184-
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
185-
186-
val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
187-
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
188-
var numSkewedLeft = 0
189-
var numSkewedRight = 0
190-
for (partitionIndex <- 0 until numPartitions) {
191-
val leftActualSize = leftActualSizes(partitionIndex)
192-
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
193-
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
194-
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
195-
196-
val rightActualSize = rightActualSizes(partitionIndex)
197-
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
198-
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
199-
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
200-
201-
// A skewed partition should never be coalesced, but skip it here just to be safe.
202-
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
203-
val reducerId = leftPartSpec.startReducerIndex
204-
val skewSpecs = createSkewPartitionSpecs(
205-
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
206-
if (skewSpecs.isDefined) {
207-
logDebug(s"Left side partition $partitionIndex " +
208-
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
209-
s"split it into ${skewSpecs.get.length} parts.")
210-
numSkewedLeft += 1
179+
// find the shuffleStage from the plan tree
180+
val leftOpt = findShuffleStage(s1)
181+
val rightOpt = findShuffleStage(s2)
182+
if (leftOpt.isEmpty || rightOpt.isEmpty) {
183+
smj
184+
} else {
185+
val left = leftOpt.get
186+
val right = rightOpt.get
187+
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
188+
val numPartitions = left.partitionsWithSizes.length
189+
// We use the median size of the original shuffle partitions to detect skewed partitions.
190+
val leftMedSize = medianSize(left.mapStats)
191+
val rightMedSize = medianSize(right.mapStats)
192+
logDebug(
193+
s"""
194+
|Optimizing skewed join.
195+
|Left side partitions size info:
196+
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
197+
198+
|Right side partitio
199+
200+
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
201+
""".stripMargin)
202+
val canSplitLeft = canSplitLeftSide(joinType)
203+
val canSplitRight = canSplitRightSide(joinType)
204+
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
205+
// the final data distribution is even (coalesced partitions + split partitions).
206+
val leftActualSizes = left.partitionsWithSizes.map(_._2)
207+
val rightActualSizes = right.partitionsWithSizes.map(_._2)
208+
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
209+
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)
210+
211+
val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
212+
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
213+
var numSkewedLeft = 0
214+
var numSkewedRight = 0
215+
for (partitionIndex <- 0 until numPartitions) {
216+
val leftActualSize = leftActualSizes(partitionIndex)
217+
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
218+
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
219+
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex
220+
221+
val rightActualSize = rightActualSizes(partitionIndex)
222+
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
223+
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
224+
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex
225+
226+
// A skewed partition should never be coalesced, but skip it here just to be safe.
227+
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
228+
val reducerId = leftPartSpec.startReducerIndex
229+
val skewSpecs = createSkewPartitionSpecs(
230+
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
231+
if (skewSpecs.isDefined) {
232+
logDebug(s"Left side partition $partitionIndex " +
233+
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
234+
s"split it into ${skewSpecs.get.length} parts.")
235+
numSkewedLeft += 1
236+
}
237+
skewSpecs.getOrElse(Seq(leftPartSpec))
238+
} else {
239+
Seq(leftPartSpec)
211240
}
212-
skewSpecs.getOrElse(Seq(leftPartSpec))
213-
} else {
214-
Seq(leftPartSpec)
215-
}
216241

217-
// A skewed partition should never be coalesced, but skip it here just to be safe.
218-
val rightParts = if (isRightSkew && !isRightCoalesced) {
219-
val reducerId = rightPartSpec.startReducerIndex
220-
val skewSpecs = createSkewPartitionSpecs(
221-
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
222-
if (skewSpecs.isDefined) {
223-
logDebug(s"Right side partition $partitionIndex " +
224-
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
225-
s"split it into ${skewSpecs.get.length} parts.")
226-
numSkewedRight += 1
242+
// A skewed partition should never be coalesced, but skip it here just to be safe.
243+
val rightParts = if (isRightSkew && !isRightCoalesced) {
244+
val reducerId = rightPartSpec.startReducerIndex
245+
val skewSpecs = createSkewPartitionSpecs(
246+
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
247+
if (skewSpecs.isDefined) {
248+
logDebug(s"Right side partition $partitionIndex " +
249+
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
250+
s"split it into ${skewSpecs.get.length} parts.")
251+
numSkewedRight += 1
252+
}
253+
skewSpecs.getOrElse(Seq(rightPartSpec))
254+
} else {
255+
Seq(rightPartSpec)
227256
}
228-
skewSpecs.getOrElse(Seq(rightPartSpec))
229-
} else {
230-
Seq(rightPartSpec)
231-
}
232257

233-
for {
234-
leftSidePartition <- leftParts
235-
rightSidePartition <- rightParts
236-
} {
237-
leftSidePartitions += leftSidePartition
238-
rightSidePartitions += rightSidePartition
258+
for {
259+
leftSidePartition <- leftParts
260+
rightSidePartition <- rightParts
261+
} {
262+
leftSidePartitions += leftSidePartition
263+
rightSidePartitions += rightSidePartition
264+
}
239265
}
240-
}
241266

242-
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
243-
if (numSkewedLeft > 0 || numSkewedRight > 0) {
244-
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
245-
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
246-
smj.copy(
247-
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
248-
} else {
249-
smj
267+
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
268+
if (numSkewedLeft > 0 || numSkewedRight > 0) {
269+
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
270+
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
271+
val newSmj = replaceSkewedShufleReader(
272+
replaceSkewedShufleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
273+
newSmj.copy(isSkewJoin = true)
274+
} else {
275+
smj
276+
}
250277
}
251278
}
252279

@@ -263,15 +290,19 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
263290
val shuffleStages = collectShuffleStages(plan)
264291

265292
if (shuffleStages.length == 2) {
266-
// When multi table join, there will be too many complex combination to consider.
267-
// Currently we only handle 2 table join like following use case.
293+
// SPARK-32201. Skew join supports below pattern, ".." may contain any number of nodes,
294+
// includes such as BroadcastHashJoinExec. So it can handle more than two tables join.
268295
// SMJ
269296
// Sort
270-
// Shuffle
297+
// ..
298+
// Shuffle
271299
// Sort
272-
// Shuffle
300+
// ..
301+
// Shuffle
273302
val optimizePlan = optimizeSkewJoin(plan)
274-
val numShuffles = ensureRequirements.apply(optimizePlan).collect {
303+
val ensuredPlan = ensureRequirements.apply(optimizePlan)
304+
println(ensuredPlan)
305+
val numShuffles = ensuredPlan.collect {
275306
case e: ShuffleExchangeExec => e
276307
}.length
277308

@@ -316,6 +347,23 @@ private object ShuffleStage {
316347
}
317348
Some(ShuffleStageInfo(s, mapStats, partitions))
318349

350+
case _: LeafExecNode => None
351+
352+
case _ @ UnaryExecNode((_, ShuffleStage(ss: ShuffleStageInfo))) =>
353+
Some(ss)
354+
355+
case b: BinaryExecNode =>
356+
b.left match {
357+
case _ @ ShuffleStage(ss: ShuffleStageInfo) =>
358+
Some(ss)
359+
case _ =>
360+
b.right match {
361+
case _ @ ShuffleStage(ss: ShuffleStageInfo) =>
362+
Some(ss)
363+
case _ => None
364+
}
365+
}
366+
319367
case _ => None
320368
}
321369
}

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -692,23 +692,6 @@ class AdaptiveQueryExecSuite
692692
'id as "value2")
693693
.createOrReplaceTempView("skewData2")
694694

695-
def checkSkewJoin(
696-
joins: Seq[SortMergeJoinExec],
697-
leftSkewNum: Int,
698-
rightSkewNum: Int): Unit = {
699-
assert(joins.size == 1 && joins.head.isSkewJoin)
700-
assert(joins.head.left.collect {
701-
case r: CustomShuffleReaderExec => r
702-
}.head.partitionSpecs.collect {
703-
case p: PartialReducerPartitionSpec => p.reducerIndex
704-
}.distinct.length == leftSkewNum)
705-
assert(joins.head.right.collect {
706-
case r: CustomShuffleReaderExec => r
707-
}.head.partitionSpecs.collect {
708-
case p: PartialReducerPartitionSpec => p.reducerIndex
709-
}.distinct.length == rightSkewNum)
710-
}
711-
712695
// skewed inner join optimization
713696
val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
714697
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
@@ -730,6 +713,67 @@ class AdaptiveQueryExecSuite
730713
}
731714
}
732715

716+
private def checkSkewJoin(
717+
joins: Seq[SortMergeJoinExec],
718+
leftSkewNum: Int,
719+
rightSkewNum: Int): Unit = {
720+
assert(joins.size == 1 && joins.head.isSkewJoin)
721+
assert(joins.head.left.collect {
722+
case r: CustomShuffleReaderExec => r
723+
}.head.partitionSpecs.collect {
724+
case p: PartialReducerPartitionSpec => p.reducerIndex
725+
}.distinct.length == leftSkewNum)
726+
assert(joins.head.right.collect {
727+
case r: CustomShuffleReaderExec => r
728+
}.head.partitionSpecs.collect {
729+
case p: PartialReducerPartitionSpec => p.reducerIndex
730+
}.distinct.length == rightSkewNum)
731+
}
732+
733+
test("SPARK-32201: handle general skew join pattern") {
734+
withSQLConf(
735+
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
736+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
737+
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
738+
SQLConf.SHUFFLE_PARTITIONS.key -> "100",
739+
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
740+
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
741+
withTempView("skewData1", "skewData2") {
742+
spark
743+
.range(0, 1000, 1, 10)
744+
.select(
745+
when('id < 250, 249)
746+
.when('id >= 750, 1000)
747+
.otherwise('id).as("key1"),
748+
'id as "value1")
749+
.createOrReplaceTempView("skewData1")
750+
751+
spark
752+
.range(0, 1000, 1, 10)
753+
.select(
754+
when('id < 250, 249)
755+
.otherwise('id).as("key2"),
756+
'id as "value2")
757+
.createOrReplaceTempView("skewData2")
758+
val sqlText =
759+
"""
760+
|SELECT * FROM
761+
| skewData1 AS data1
762+
| INNER JOIN
763+
| (
764+
| SELECT skewData2.key2, sum(skewData2.value2) AS sum2
765+
| FROM skewData2 GROUP BY skewData2.key2
766+
| ) AS data2
767+
|ON data1.key1 = data2.key2
768+
|""".stripMargin
769+
770+
val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)
771+
val innerSmj = findTopLevelSortMergeJoin(adaptivePlan)
772+
checkSkewJoin(innerSmj, 2, 0)
773+
}
774+
}
775+
}
776+
733777
test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
734778
withSQLConf(
735779
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {

0 commit comments

Comments
 (0)