diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index b5d287ca7ac79..396c9c9d6b4e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.commons.io.FileUtils
 
-import org.apache.spark.{MapOutputTrackerMaster, SparkEnv}
+import org.apache.spark.{MapOutputStatistics, MapOutputTrackerMaster, SparkEnv}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
@@ -70,9 +70,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     size > conf.getConf(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD)
   }
 
-  private def medianSize(sizes: Seq[Long]): Long = {
-    val numPartitions = sizes.length
-    val bytes = sizes.sorted
+  private def medianSize(stats: MapOutputStatistics): Long = {
+    val numPartitions = stats.bytesByPartitionId.length
+    val bytes = stats.bytesByPartitionId.sorted
     numPartitions match {
       case _ if (numPartitions % 2 == 0) =>
         math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1)
@@ -163,16 +163,16 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
         if supportedJoinTypes.contains(joinType) =>
       assert(left.partitionsWithSizes.length == right.partitionsWithSizes.length)
       val numPartitions = left.partitionsWithSizes.length
-      // Use the median size of the actual (coalesced) partition sizes to detect skewed partitions.
-      val leftMedSize = medianSize(left.partitionsWithSizes.map(_._2))
-      val rightMedSize = medianSize(right.partitionsWithSizes.map(_._2))
+      // We use the median size of the original shuffle partitions to detect skewed partitions.
+      val leftMedSize = medianSize(left.mapStats)
+      val rightMedSize = medianSize(right.mapStats)
       logDebug(
         s"""
           |Optimizing skewed join.
           |Left side partitions size info:
-          |${getSizeInfo(leftMedSize, left.partitionsWithSizes.map(_._2))}
+          |${getSizeInfo(leftMedSize, left.mapStats.bytesByPartitionId)}
           |Right side partitions size info:
-          |${getSizeInfo(rightMedSize, right.partitionsWithSizes.map(_._2))}
+          |${getSizeInfo(rightMedSize, right.mapStats.bytesByPartitionId)}
         """.stripMargin)
       val canSplitLeft = canSplitLeftSide(joinType)
       val canSplitRight = canSplitRightSide(joinType)
@@ -291,15 +291,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
 private object ShuffleStage {
   def unapply(plan: SparkPlan): Option[ShuffleStageInfo] = plan match {
     case s: ShuffleQueryStageExec if s.mapStats.isDefined =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = sizes.zipWithIndex.map {
         case (size, i) => CoalescedPartitionSpec(i, i + 1) -> size
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case CustomShuffleReaderExec(s: ShuffleQueryStageExec, partitionSpecs)
         if s.mapStats.isDefined && partitionSpecs.nonEmpty =>
-      val sizes = s.mapStats.get.bytesByPartitionId
+      val mapStats = s.mapStats.get
+      val sizes = mapStats.bytesByPartitionId
       val partitions = partitionSpecs.map {
         case spec @ CoalescedPartitionSpec(start, end) =>
           var sum = 0L
@@ -312,7 +314,7 @@ private object ShuffleStage {
         case other => throw new IllegalArgumentException(
           s"Expect CoalescedPartitionSpec but got $other")
       }
-      Some(ShuffleStageInfo(s, partitions))
+      Some(ShuffleStageInfo(s, mapStats, partitions))
 
     case _ => None
   }
@@ -320,4 +322,5 @@ private object ShuffleStage {
 
 private case class ShuffleStageInfo(
     shuffleStage: ShuffleQueryStageExec,
+    mapStats: MapOutputStatistics,
    partitionsWithSizes: Seq[(CoalescedPartitionSpec, Long)])
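
Note on the median logic above: the patch keeps `medianSize` intact but feeds it the original per-map-output partition sizes (`MapOutputStatistics.bytesByPartitionId`) rather than the already-coalesced reader sizes. The following is a minimal standalone sketch of that median computation, extracted onto a plain `Seq[Long]` for illustration only; the odd-length branch is elided by the hunk above, so its body here (middle element, floored at 1 byte) is an assumption based on the even-length branch that is shown.

// MedianSizeSketch: standalone illustration, not Spark code.
object MedianSizeSketch {
  def medianSize(sizes: Seq[Long]): Long = {
    val numPartitions = sizes.length
    val bytes = sizes.sorted
    numPartitions match {
      case _ if numPartitions % 2 == 0 =>
        // Even count: average the two middle values, floored at 1 byte
        // (matches the branch visible in the hunk at old line 70).
        math.max((bytes(numPartitions / 2) + bytes(numPartitions / 2 - 1)) / 2, 1)
      case _ =>
        // Odd count: take the middle value, floored at 1 byte (assumed;
        // this branch is cut off by the hunk above).
        math.max(bytes(numPartitions / 2), 1)
    }
  }

  def main(args: Array[String]): Unit = {
    // One heavily skewed partition among otherwise similar sizes:
    // the median (110) is unaffected by the 5000-byte outlier,
    // which is why median rather than mean is used for skew detection.
    println(medianSize(Seq(100L, 120L, 5000L, 110L, 90L))) // prints 110
  }
}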
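Note on the `CustomShuffleReaderExec` branch: each `CoalescedPartitionSpec` covers a half-open range of reducer partitions, and its size is the sum of the original per-partition sizes in that range; the hunk truncates right after `var sum = 0L`, so the sketch below is a hypothetical reconstruction of that summation under simplified stand-in types, not the actual Spark source.

// CoalescedSizeSketch: hypothetical standalone illustration, not Spark code.
object CoalescedSizeSketch {
  // Simplified stand-in for Spark's CoalescedPartitionSpec:
  // covers reducer partitions [startReducerIndex, endReducerIndex).
  final case class CoalescedPartitionSpec(startReducerIndex: Int, endReducerIndex: Int)

  // Sum the original map-output sizes across the spec's reducer range,
  // mirroring what the elided lines after `var sum = 0L` presumably do.
  def coalescedSize(spec: CoalescedPartitionSpec, bytesByPartitionId: Array[Long]): Long = {
    var sum = 0L
    var i = spec.startReducerIndex
    while (i < spec.endReducerIndex) {
      sum += bytesByPartitionId(i)
      i += 1
    }
    sum
  }

  def main(args: Array[String]): Unit = {
    val sizes = Array(100L, 200L, 300L, 400L)
    // Reducer partitions 1 and 2 coalesced into one reader partition: 200 + 300 = 500.
    println(coalescedSize(CoalescedPartitionSpec(1, 3), sizes)) // prints 500
  }
}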