diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
index ad828006e3315..74b7fbd317fc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewedJoin.scala
@@ -121,22 +121,6 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     stage.shuffle.shuffleDependency.rdd.partitions.length
   }
 
-  private def getShuffleQueryStage(plan : SparkPlan): Option[ShuffleQueryStageExec] =
-    plan match {
-      case stage: ShuffleQueryStageExec => Some(stage)
-      case SortExec(_, _, s: ShuffleQueryStageExec, _) =>
-        Some(s)
-      case _ => None
-    }
-
-  private def reOptimizeChild(
-      skewedReader: SkewedPartitionReaderExec,
-      child: SparkPlan): SparkPlan = child match {
-    case sort @ SortExec(_, _, s: ShuffleQueryStageExec, _) =>
-      sort.copy(child = skewedReader)
-    case _: ShuffleQueryStageExec => skewedReader
-  }
-
   private def getSizeInfo(medianSize: Long, maxSize: Long): String = {
     s"median size: $medianSize, max size: ${maxSize}"
   }
@@ -153,11 +137,10 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
    * 4. Finally union the above 3 split smjs and the origin smj.
    */
   def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, leftPlan, rightPlan)
-      if (getShuffleQueryStage(leftPlan).nonEmpty && getShuffleQueryStage(rightPlan).nonEmpty) &&
-        supportedJoinTypes.contains(joinType) =>
-      val left = getShuffleQueryStage(leftPlan).get
-      val right = getShuffleQueryStage(rightPlan).get
+    case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
+        s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
+        s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _)
+      if (supportedJoinTypes.contains(joinType)) =>
       val leftStats = getStatistics(left)
       val rightStats = getStatistics(right)
       val numPartitions = leftStats.bytesByPartitionId.length
@@ -209,20 +192,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
             left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
           val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId,
             rightMapIdStartIndices(j), rightEndMapId)
-          val skewedLeft = reOptimizeChild(leftSkewedReader, leftPlan)
-          val skewedRight = reOptimizeChild(rightSkewedReader, rightPlan)
           subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
-            skewedLeft, skewedRight)
+            s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader), true)
         }
       }
     }
     logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
     if (skewedPartitions.nonEmpty) {
-      val optimizedSmj = smj.transformUp {
-        case shuffleStage: ShuffleQueryStageExec if shuffleStage.id == left.id ||
-          shuffleStage.id == right.id =>
-          PartialShuffleReaderExec(shuffleStage, skewedPartitions.toSet)
-      }
+      val optimizedSmj = smj.copy(
+        left = s1.copy(child = PartialShuffleReaderExec(left, skewedPartitions.toSet)),
+        right = s2.copy(child = PartialShuffleReaderExec(right, skewedPartitions.toSet)),
+        isPartial = true)
       subJoins += optimizedSmj
       UnionExec(subJoins)
     } else {
@@ -236,8 +216,6 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
-    case _: LocalShuffleReaderExec => Nil
-    case _: CoalescedShuffleReaderExec => Nil
     case stage: ShuffleQueryStageExec => Seq(stage)
     case _ => plan.children.flatMap(collectShuffleStages)
   }
@@ -247,9 +225,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
     if (shuffleStages.length == 2) {
       // When multi table join, there will be too many complex combination to consider.
       // Currently we only handle 2 table join like following two use cases.
-      //   SMJ              SMJ
-      //     Sort             Shuffle
-      //       Shuffle  or  Shuffle
+      //   SMJ
+      //     Sort
+      //       Shuffle
       //     Sort
       //       Shuffle
       val optimizePlan = optimizeSkewJoin(plan)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 033404ccac44d..68cf2200bf73a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -207,10 +207,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
           left, right)
 
-      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isPartial) =>
         val (reorderedLeftKeys, reorderedRightKeys) =
           reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition,
+          left, right, isPartial)
 
       case other => other
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 985e1db2736fc..6384aed6a78e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -41,7 +41,8 @@ case class SortMergeJoinExec(
     joinType: JoinType,
     condition: Option[Expression],
     left: SparkPlan,
-    right: SparkPlan) extends BinaryExecNode with CodegenSupport {
+    right: SparkPlan,
+    isPartial: Boolean = false) extends BinaryExecNode with CodegenSupport {
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -96,15 +97,8 @@ case class SortMergeJoinExec(
       s"${getClass.getSimpleName} should not take $x as the JoinType")
   }
 
-  private def containSkewedReader(plan: SparkPlan): Boolean = plan match {
-    case s: SkewedPartitionReaderExec => true
-    case p: PartialShuffleReaderExec => true
-    case s: SortExec => containSkewedReader(s.child)
-    case _ => false
-  }
-
   override def requiredChildDistribution: Seq[Distribution] = {
-    if (containSkewedReader(left)) {
+    if (isPartial) {
       UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
     } else {
       HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 563c42901ecaa..a10db54855c8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -742,7 +742,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
     val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
     outputPlan match {
-      case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _, _) =>
         assert(leftKeys == Seq(exprA, exprA))
         assert(rightKeys == Seq(exprB, exprC))
       case _ => fail()
@@ -766,7 +766,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       SortExec(_, _,
         ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
       SortExec(_, _,
-        ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _)) =>
+        ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _),
+          _, _), _), _) =>
         assert(leftKeys === smjExec.leftKeys)
         assert(rightKeys === smjExec.rightKeys)
         assert(leftKeys === leftPartitioningExpressions)
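
Reviewer note: the heart of this patch is that SortMergeJoinExec now carries an explicit isPartial flag instead of re-deriving "my children are skew readers" by walking the child plan. The standalone Scala sketch below models that trade-off; every type in it is a hypothetical stand-in for the real Spark plan nodes, not Spark code. With the flag, requiredChildDistribution answers in O(1), and EnsureRequirements has no reason to insert a fresh exchange under the skew-reader sub-joins.

// Toy model of the isPartial change (hypothetical types, not Spark's).
sealed trait Plan { def children: Seq[Plan] }
case class Shuffle() extends Plan { def children: Seq[Plan] = Nil }
case class SkewedReader(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }
case class Sort(child: Plan) extends Plan { def children: Seq[Plan] = Seq(child) }

case class Smj(left: Plan, right: Plan, isPartial: Boolean = false) extends Plan {
  def children: Seq[Plan] = Seq(left, right)

  // Old approach (what this patch removes): recursively inspect the child
  // plan for a skew reader each time the distribution is requested.
  private def containSkewedReader(plan: Plan): Boolean = plan match {
    case _: SkewedReader => true
    case Sort(child) => containSkewedReader(child)
    case _ => false
  }
  def oldRequiresClustering: Boolean = !containSkewedReader(left)

  // New approach: the rule that built the join already knows it is a
  // partial (skew-handling) join, so a constructor flag answers directly.
  def newRequiresClustering: Boolean = !isPartial
}

object Demo extends App {
  val skewJoin = Smj(
    left = Sort(SkewedReader(Shuffle())),
    right = Sort(SkewedReader(Shuffle())),
    isPartial = true)
  // Both checks agree on this shape, but only the flag keeps working if the
  // child shape changes: the removed check had to enumerate every reader and
  // Sort combination by hand, so any new wrapper node would silently break it.
  assert(skewJoin.oldRequiresClustering == skewJoin.newRequiresClustering)
  assert(!skewJoin.newRequiresClustering)
}

The same reasoning explains the tightened pattern match in optimizeSkewJoin: once the rule only fires on Sort-over-ShuffleQueryStageExec children, s1.copy(child = ...) and s2.copy(child = ...) can rebuild the sorts directly, which is what makes the removed getShuffleQueryStage and reOptimizeChild helpers dead code.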