@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
@@ -70,9 +70,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD)
   }
 
-  private def medianSize(sizes: Seq[Long]): Long = {
-    val numPartitions = sizes.length
-    val bytes = sizes.sorted
+  private def medianSize(stats: MapOutputStatistics): Long = {
+    val numPartitions = stats.bytesByPartitionId.length
+    val bytes = stats.bytesByPartitionId.sorted
     numPartitions match {
       case _ if (numPartitions % 2 == 0) =>
         math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1)
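
For intuition, the even/odd median logic in medianSize can be sketched standalone in plain Scala. The byte counts here are made up, and the odd-length branch is an assumption (the hunk is cut off after the even case):

  object MedianSketch extends App {
    // Hypothetical per-partition shuffle byte counts (not from the PR).
    val bytesByPartitionId = Array(10L, 800L, 12L, 15L)
    val bytes = bytesByPartitionId.sorted // Array(10, 12, 15, 800)
    val n = bytes.length
    val median =
      if (n % 2 == 0) math.max((bytes(n / 2) + bytes(n / 2 - 1)) / 2, 1) // (12 + 15) / 2 = 13
      else math.max(bytes(n / 2), 1) // assumed odd-length branch, clamped to at least 1
    println(median) // 13: the skewed 800-byte partition does not drag the median up
  }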
@@ -163,16 +163,16 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         if supportedJoinTypes.contains(joinType) =>
       assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
       val numPartitions = left.partitionsWithSizes.length
-      // Use the median size of the actual (coalesced) partition sizes to detect skewed partitions.
-      val leftMedSize = medianSize(left.partitionsWithSizes.map(_._2))
-      val rightMedSize = medianSize(right.partitionsWithSizes.map(_._2))
+      // We use the median size of the original shuffle partitions to detect skewed partitions.
+      val leftMedSize = medianSize(left.mapStats)
+      val rightMedSize = medianSize(right.mapStats)
       logDebug(
         s"""
           |Optimizing skewed join.
           |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.partitionsWithSizes.map(_._2))}
+          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
           |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.partitionsWithSizes.map(_._2))}
+          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
         """.stripMargin)
       val canSplitLeft = canSplitLeftSide(joinType)
       val canSplitRight = canSplitRightSide(joinType)
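
A toy illustration (hypothetical sizes, plain Scala) of why the original shuffle partition sizes are the better skew signal: coalescing merges many small partitions into fewer, larger ones, which inflates the coalesced median and can mask a skewed partition:

  object SkewSignalSketch extends App {
    def median(sizes: Seq[Long]): Long = {
      val bytes = sizes.sorted
      val n = bytes.length
      if (n % 2 == 0) math.max((bytes(n / 2) + bytes(n / 2 - 1)) / 2, 1)
      else math.max(bytes(n / 2), 1)
    }
    // Original shuffle partitions: the 900-byte partition is clearly skewed.
    val original = Seq(5L, 5L, 5L, 5L, 900L)
    // After coalescing the four small partitions into one:
    val coalesced = Seq(20L, 900L)
    println(median(original))  // 5   -> 900 is 180x the median
    println(median(coalesced)) // 460 -> 900 is under 2x the median; skew is masked
  }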
@@ -291,15 +291,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 private object ShuffleStage {
   def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
     case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = sizes.zipWithIndex.map {
         case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs)
         if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = partitionSpecs.map {
         case spec @ CoalescedPartitionSpec(start, end) =>
           var sum = 0L
@@ -312,12 +312,13 @@ private object ShuffleStage {
         case other => throw new IllegalArgumentException(
           s"Expect CoalescedPartitionSpec but got $other")
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case _ => None
   }
 }
 
 private case class ShuffleStageInfo(
     shuffleStage: ShuffleQueryStageExec,
+    mapStats: MapOutputStatistics,
     partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
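
A self-contained toy model (plain Scala, with hypothetical "Toy" names mirroring the shapes in this diff, not the real Spark types) of what threading mapStats through ShuffleStageInfo buys: callers read the original per-partition sizes for skew detection and the coalesced view for split planning from one value, without re-unwrapping the Option:

  object ShuffleStageInfoSketch extends App {
    // Toy stand-ins for the Spark types touched by this diff.
    case class MapOutputStatisticsToy(bytesByPartitionId: Array[Long])
    case class CoalescedPartitionSpecToy(startReducerIndex: Int, endReducerIndex: Int)
    case class ShuffleStageInfoToy(
        mapStats: MapOutputStatisticsToy,
        partitionsWithSizes: Seq[(CoalescedPartitionSpecToy, Long)])

    val stats = MapOutputStatisticsToy(Array(5L, 5L, 5L, 5L, 900L))
    val info = ShuffleStageInfoToy(
      stats,
      Seq(CoalescedPartitionSpecToy(0, 4) -> 20L, CoalescedPartitionSpecToy(4, 5) -> 900L))

    // Skew detection keys off the original sizes carried in mapStats...
    println(info.mapStats.bytesByPartitionId.mkString(", "))  // 5, 5, 5, 5, 900
    // ...while split planning still reads the coalesced (spec, size) pairs.
    println(info.partitionsWithSizes.map(_._2).mkString(", ")) // 20, 900
  }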