Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,6 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
stage.shuffle.shuffleDependency.rdd.partitions.length
}

private def getShuffleQueryStage(plan : SparkPlan): Option[ShuffleQueryStageExec] =
plan match {
case stage: ShuffleQueryStageExec => Some(stage)
case SortExec(_, _, s: ShuffleQueryStageExec, _) =>
Some(s)
case _ => None
}

private def reOptimizeChild(
skewedReader: SkewedPartitionReaderExec,
child: SparkPlan): SparkPlan = child match {
case sort @ SortExec(_, _, s: ShuffleQueryStageExec, _) =>
sort.copy(child = skewedReader)
case _: ShuffleQueryStageExec => skewedReader
}

private def getSizeInfo(medianSize: Long, maxSize: Long): String = {
s"median size: $medianSize, max size: ${maxSize}"
}
Expand All @@ -153,11 +137,10 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
* 4. Finally union the above 3 split smjs and the origin smj.
*/
def optimizeSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, leftPlan, rightPlan)
if (getShuffleQueryStage(leftPlan).nonEmpty && getShuffleQueryStage(rightPlan).nonEmpty) &&
supportedJoinTypes.contains(joinType) =>
val left = getShuffleQueryStage(leftPlan).get
val right = getShuffleQueryStage(rightPlan).get
case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _), _)
if (supportedJoinTypes.contains(joinType)) =>
val leftStats = getStatistics(left)
val rightStats = getStatistics(right)
val numPartitions = leftStats.bytesByPartitionId.length
Expand Down Expand Up @@ -209,20 +192,17 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId,
rightMapIdStartIndices(j), rightEndMapId)
val skewedLeft = reOptimizeChild(leftSkewedReader, leftPlan)
val skewedRight = reOptimizeChild(rightSkewedReader, rightPlan)
subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
skewedLeft, skewedRight)
s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader), true)
}
}
}
logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
if (skewedPartitions.nonEmpty) {
val optimizedSmj = smj.transformUp {
case shuffleStage: ShuffleQueryStageExec if shuffleStage.id == left.id ||
shuffleStage.id == right.id =>
PartialShuffleReaderExec(shuffleStage, skewedPartitions.toSet)
}
val optimizedSmj = smj.copy(
left = s1.copy(child = PartialShuffleReaderExec(left, skewedPartitions.toSet)),
right = s2.copy(child = PartialShuffleReaderExec(right, skewedPartitions.toSet)),
isPartial = true)
subJoins += optimizedSmj
UnionExec(subJoins)
} else {
Expand All @@ -236,8 +216,6 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
}

def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
case _: LocalShuffleReaderExec => Nil
case _: CoalescedShuffleReaderExec => Nil
case stage: ShuffleQueryStageExec => Seq(stage)
case _ => plan.children.flatMap(collectShuffleStages)
}
Expand All @@ -247,9 +225,9 @@ case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
if (shuffleStages.length == 2) {
// When multi table join, there will be too many complex combination to consider.
// Currently we only handle 2 table join like following two use cases.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor: only one case now

// SMJ SMJ
// Sort Shuffle
// Shuffle or Shuffle
// SMJ
// Sort
// Shuffle
// Sort
// Shuffle
val optimizePlan = optimizeSkewJoin(plan)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,11 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
left, right)

case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right, isPartial) =>
val (reorderedLeftKeys, reorderedRightKeys) =
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition,
left, right, isPartial)

case other => other
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ case class SortMergeJoinExec(
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryExecNode with CodegenSupport {
right: SparkPlan,
isPartial: Boolean = false) extends BinaryExecNode with CodegenSupport {

override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
Expand Down Expand Up @@ -96,15 +97,8 @@ case class SortMergeJoinExec(
s"${getClass.getSimpleName} should not take $x as the JoinType")
}

private def containSkewedReader(plan: SparkPlan): Boolean = plan match {
case s: SkewedPartitionReaderExec => true
case p: PartialShuffleReaderExec => true
case s: SortExec => containSkewedReader(s.child)
case _ => false
}

override def requiredChildDistribution: Seq[Distribution] = {
if (containSkewedReader(left)) {
if (isPartial) {
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
} else {
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {

val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
outputPlan match {
case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _, _) =>
assert(leftKeys == Seq(exprA, exprA))
assert(rightKeys == Seq(exprB, exprC))
case _ => fail()
Expand All @@ -766,7 +766,8 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _, _), _),
SortExec(_, _,
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _, _), _)) =>
ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _),
_, _), _), _) =>
assert(leftKeys === smjExec.leftKeys)
assert(rightKeys === smjExec.rightKeys)
assert(leftKeys === leftPartitioningExpressions)
Expand Down