Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,21 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
sizes.sum / sizes.length
}

private def findShuffleStage(plan: SparkPlan): Option[ShuffleStageInfo] = {
plan collectFirst {
case _ @ ShuffleStage(shuffleStageInfo) =>
shuffleStageInfo
}
}

private def replaceSkewedShufleReader(
smj: SparkPlan, newCtm: CustomShuffleReaderExec): SparkPlan = {
smj transformUp {
case _ @ CustomShuffleReaderExec(child, _) if child.sameResult(newCtm.child) =>
newCtm
}
}

/*
* This method aim to optimize the skewed join with the following steps:
* 1. Check whether the shuffle partition is skewed based on the median size
Expand All @@ -158,95 +173,107 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
*/
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(_, _, joinType, _,
s1 @ SortExec(_, _, ShuffleStage(left: ShuffleStageInfo), _),
s2 @ SortExec(_, _, ShuffleStage(right: ShuffleStageInfo), _), _)
s1 @ SortExec(_, _, _, _),
s2 @ SortExec(_, _, _, _), _)
if supportedJoinTypes.contains(joinType) =>
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
|Right side partitions size info:
|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType)
val canSplitRight = canSplitRightSide(joinType)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
// find the shuffleStage from the plan tree
val leftOpt = findShuffleStage(s1)
val rightOpt = findShuffleStage(s2)
if (leftOpt.isEmpty || rightOpt.isEmpty) {
smj
} else {
val left = leftOpt.get
val right = rightOpt.get
assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
val numPartitions = left.partitionsWithSizes.length
// We use the median size of the original shuffle partitions to detect skewed partitions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is very hard to reason about. We need to clearly define:

  1. what nodes can appear between the shuffle stage and SMJ. As we discussed before, Agg can't appear at the skew side.
  2. how to estimate the size? Since there are nodes in the middle, the stats of the shuffle stage may not be accurate for the final join child. (e.g. Filter in the middle)

Copy link
Contributor Author

@LantaoJin LantaoJin Jul 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. what nodes can appear between the shuffle stage and SMJ. As we discussed before, Agg can't appear at the skew side.

In the canSplitLeftSide and canSplitRightSide, I added a allUnspecifiedDistribution(plan) check. Current we only support the nodes with UnspecifiedDistribution.

  1. how to estimate the size? Since there are nodes in the middle, the stats of the shuffle stage may not be accurate for the final join child. (e.g. Filter in the middle)

Filter should be pushdown to leaf, I didn't see this user case. Project may be a command case in the middle? Yes. the input size of shuffle stage may not be accurate. But the disadvantage is launching more tasks. I think the benefit from handling the skewing is more important than the disadvantage.

val leftMedSize = medianSize(left.mapStats)
val rightMedSize = medianSize(right.mapStats)
logDebug(
s"""
|Optimizing skewed join.
|Left side partitions size info:
|${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}

|Right side partitio

|${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
""".stripMargin)
val canSplitLeft = canSplitLeftSide(joinType)
val canSplitRight = canSplitRightSide(joinType)
// We use the actual partition sizes (may be coalesced) to calculate target size, so that
// the final data distribution is even (coalesced partitions + split partitions).
val leftActualSizes = left.partitionsWithSizes.map(_._2)
val rightActualSizes = right.partitionsWithSizes.map(_._2)
val leftTargetSize = targetSize(leftActualSizes, leftMedSize)
val rightTargetSize = targetSize(rightActualSizes, rightMedSize)

val leftSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
val rightSidePartitions = mutable.ArrayBuffer.empty[ShufflePartitionSpec]
var numSkewedLeft = 0
var numSkewedRight = 0
for (partitionIndex <- 0 until numPartitions) {
val leftActualSize = leftActualSizes(partitionIndex)
val isLeftSkew = isSkewed(leftActualSize, leftMedSize) && canSplitLeft
val leftPartSpec = left.partitionsWithSizes(partitionIndex)._1
val isLeftCoalesced = leftPartSpec.startReducerIndex + 1 < leftPartSpec.endReducerIndex

val rightActualSize = rightActualSizes(partitionIndex)
val isRightSkew = isSkewed(rightActualSize, rightMedSize) && canSplitRight
val rightPartSpec = right.partitionsWithSizes(partitionIndex)._1
val isRightCoalesced = rightPartSpec.startReducerIndex + 1 < rightPartSpec.endReducerIndex

// A skewed partition should never be coalesced, but skip it here just to be safe.
val leftParts = if (isLeftSkew && !isLeftCoalesced) {
val reducerId = leftPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
left.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, leftTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Left side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(leftActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedLeft += 1
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}
skewSpecs.getOrElse(Seq(leftPartSpec))
} else {
Seq(leftPartSpec)
}

// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
// A skewed partition should never be coalesced, but skip it here just to be safe.
val rightParts = if (isRightSkew && !isRightCoalesced) {
val reducerId = rightPartSpec.startReducerIndex
val skewSpecs = createSkewPartitionSpecs(
right.shuffleStage.shuffle.shuffleDependency.shuffleId, reducerId, rightTargetSize)
if (skewSpecs.isDefined) {
logDebug(s"Right side partition $partitionIndex " +
s"(${FileUtils.byteCountToDisplaySize(rightActualSize)}) is skewed, " +
s"split it into ${skewSpecs.get.length} parts.")
numSkewedRight += 1
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}
skewSpecs.getOrElse(Seq(rightPartSpec))
} else {
Seq(rightPartSpec)
}

for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
for {
leftSidePartition <- leftParts
rightSidePartition <- rightParts
} {
leftSidePartitions += leftSidePartition
rightSidePartitions += rightSidePartition
}
}
}

logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
smj.copy(
left = s1.copy(child = newLeft), right = s2.copy(child = newRight), isSkewJoin = true)
} else {
smj
logDebug(s"number of skewed partitions: left $numSkewedLeft, right $numSkewedRight")
if (numSkewedLeft > 0 || numSkewedRight > 0) {
val newLeft = CustomShuffleReaderExec(left.shuffleStage, leftSidePartitions)
val newRight = CustomShuffleReaderExec(right.shuffleStage, rightSidePartitions)
val newSmj = replaceSkewedShufleReader(
replaceSkewedShufleReader(smj, newLeft), newRight).asInstanceOf[SortMergeJoinExec]
newSmj.copy(isSkewJoin = true)
} else {
smj
}
}
}

Expand All @@ -263,15 +290,19 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
val shuffleStages = collectShuffleStages(plan)

if (shuffleStages.length == 2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not we break this limitation first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because this PR is not to address the case which has multiple SMJ. We have another PR to change this limitation:

  1. optimizeSingleStageSkewJoin. This is the case one table is a bucket table and the SMJ is bucketing join with one side shuffle and skewing
  2. optimizeThreeShuffleStageSkewJoin. This is to address three tables SMJ (Two SMJs in one stage and no one can be changed to BCJ in AQE).

// When multi table join, there will be too many complex combination to consider.
// Currently we only handle 2 table join like following use case.
// SPARK-32201. Skew join supports below pattern, ".." may contain any number of nodes,
// includes such as BroadcastHashJoinExec. So it can handle more than two tables join.
// SMJ
// Sort
// Shuffle
// ..
// Shuffle
// Sort
// Shuffle
// ..
// Shuffle
val optimizePlan = optimizeSkewJoin(plan)
val numShuffles = ensureRequirements.apply(optimizePlan).collect {
val ensuredPlan = ensureRequirements.apply(optimizePlan)
println(ensuredPlan)
val numShuffles = ensuredPlan.collect {
case e: ShuffleExchangeExec => e
}.length

Expand Down Expand Up @@ -316,6 +347,23 @@ private object ShuffleStage {
}
Some(ShuffleStageInfo(s, mapStats, partitions))

case _: LeafExecNode => None

case _ @ UnaryExecNode((_, ShuffleStage(ss: ShuffleStageInfo))) =>
Some(ss)

case b: BinaryExecNode =>
b.left match {
case _ @ ShuffleStage(ss: ShuffleStageInfo) =>
Some(ss)
case _ =>
b.right match {
case _ @ ShuffleStage(ss: ShuffleStageInfo) =>
Some(ss)
case _ => None
}
}

case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,23 +692,6 @@ class AdaptiveQueryExecSuite
'id as "value2")
.createOrReplaceTempView("skewData2")

def checkSkewJoin(
joins: Seq[SortMergeJoinExec],
leftSkewNum: Int,
rightSkewNum: Int): Unit = {
assert(joins.size == 1 && joins.head.isSkewJoin)
assert(joins.head.left.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == leftSkewNum)
assert(joins.head.right.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == rightSkewNum)
}

// skewed inner join optimization
val (_, innerAdaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM skewData1 join skewData2 ON key1 = key2")
Expand All @@ -730,6 +713,67 @@ class AdaptiveQueryExecSuite
}
}

private def checkSkewJoin(
joins: Seq[SortMergeJoinExec],
leftSkewNum: Int,
rightSkewNum: Int): Unit = {
assert(joins.size == 1 && joins.head.isSkewJoin)
assert(joins.head.left.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == leftSkewNum)
assert(joins.head.right.collect {
case r: CustomShuffleReaderExec => r
}.head.partitionSpecs.collect {
case p: PartialReducerPartitionSpec => p.reducerIndex
}.distinct.length == rightSkewNum)
}

test("SPARK-32201: handle general skew join pattern") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1",
SQLConf.SHUFFLE_PARTITIONS.key -> "100",
SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD.key -> "800",
SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "800") {
withTempView("skewData1", "skewData2") {
spark
.range(0, 1000, 1, 10)
.select(
when('id < 250, 249)
.when('id >= 750, 1000)
.otherwise('id).as("key1"),
'id as "value1")
.createOrReplaceTempView("skewData1")

spark
.range(0, 1000, 1, 10)
.select(
when('id < 250, 249)
.otherwise('id).as("key2"),
'id as "value2")
.createOrReplaceTempView("skewData2")
val sqlText =
"""
|SELECT * FROM
| skewData1 AS data1
| INNER JOIN
| (
| SELECT skewData2.key2, sum(skewData2.value2) AS sum2
| FROM skewData2 GROUP BY skewData2.key2
| ) AS data2
|ON data1.key1 = data2.key2
|""".stripMargin

val (_, adaptivePlan) = runAdaptiveAndVerifyResult(sqlText)
val innerSmj = findTopLevelSortMergeJoin(adaptivePlan)
checkSkewJoin(innerSmj, 2, 0)
}
}
}

test("SPARK-30291: AQE should catch the exceptions when doing materialize") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
Expand Down