@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -34,7 +34,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
return plan
}
if (!plan.collectLeaves().forall(_.isInstanceOf[QueryStageExec])
|| plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined) {
|| plan.find(_.isInstanceOf[CustomShuffleReaderExec]).isDefined) {
// If not all leaf nodes are query stages, it's not safe to reduce the number of
// shuffle partitions, because we may break the assumption that all children of a spark plan
// have same number of output partitions.
@@ -56,20 +56,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
// we should skip it when calculating the `partitionStartIndices`.
val validMetrics = shuffleStages.flatMap(_.mapStats)

// We may have different pre-shuffle partition numbers, don't reduce shuffle partition number
// in that case. For example when we union fully aggregated data (data is arranged to a single
// partition) and a result of a SortMergeJoin (multiple partitions).
val distinctNumPreShufflePartitions =
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
.getOrElse(session.sparkContext.defaultParallelism)
val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
validMetrics.toArray,
advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
minNumPartitions = minPartitionNum)
def updatePlan(partitionSpecs: Seq[ShufflePartitionSpec]): SparkPlan = {
// This transformation adds new nodes, so we must use `transformUp` here.
val stageIds = shuffleStages.map(_.id).toSet
plan.transformUp {
@@ -79,8 +66,29 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
case stage: ShuffleQueryStageExec if stageIds.contains(stage.id) =>
CustomShuffleReaderExec(stage, partitionSpecs)
}
}

if (validMetrics.isEmpty) {
Contributor commented:

If a query stage has multiple leaf shuffles and only one of them has a 0-partition input RDD, what shall we do?

Member (author) replied:

I think it's like coalescing one fewer shuffle, and it is handled by the non-empty code path.

updatePlan(Nil)
Member @viirya commented (Jun 30, 2020):

Can you add a comment for the 0-partition case?

} else {
plan
// We may have different pre-shuffle partition numbers, don't reduce shuffle partition
// number in that case. For example when we union fully aggregated data (data is arranged
// to a single partition) and a result of a SortMergeJoin (multiple partitions).
val distinctNumPreShufflePartitions =
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
if (distinctNumPreShufflePartitions.length == 1) {
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
.getOrElse(session.sparkContext.defaultParallelism)
val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
validMetrics.toArray,
advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
minNumPartitions = minPartitionNum)
updatePlan(partitionSpecs)
} else {
plan
}
}
}
}
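
For readers skimming the diff, below is a minimal, self-contained sketch (plain Scala, no Spark internals) of the decision flow the new code introduces. ShuffleStats, PartitionRange, and the greedy coalesce helper are illustrative stand-ins for MapOutputStatistics, CoalescedPartitionSpec, and ShufflePartitionsUtil.coalescePartitions, not the real implementations. The Some(Nil) branch mirrors the updatePlan(Nil) call discussed in the review thread above: when every leaf shuffle has a 0-partition input RDD there is nothing to measure, but the stages are still wrapped so that all of them report the same (empty) partitioning.

object CoalesceDecisionSketch {
  // Bytes per map-output partition for one shuffle, loosely mirroring MapOutputStatistics.
  final case class ShuffleStats(bytesByPartitionId: Array[Long])

  // A coalesced reducer range [start, end), loosely like CoalescedPartitionSpec.
  final case class PartitionRange(start: Int, end: Int)

  // Returns None when coalescing must be skipped, Some(specs) otherwise.
  // Some(Nil) models updatePlan(Nil): all shuffles had 0-partition input RDDs,
  // so the stages are still wrapped, just with no partition specs.
  def decide(
      validMetrics: Seq[ShuffleStats],
      advisoryTargetSize: Long,
      minNumPartitions: Int): Option[Seq[PartitionRange]] = {
    if (validMetrics.isEmpty) {
      Some(Nil)
    } else {
      val distinctNumPreShufflePartitions =
        validMetrics.map(_.bytesByPartitionId.length).distinct
      if (distinctNumPreShufflePartitions.length == 1) {
        Some(coalesce(validMetrics, advisoryTargetSize, minNumPartitions))
      } else {
        None // different pre-shuffle partition numbers: leave the plan unchanged
      }
    }
  }

  // Greedy packing of adjacent reducer partitions up to a target size; a rough
  // approximation of ShufflePartitionsUtil.coalescePartitions, not the real algorithm.
  private def coalesce(
      metrics: Seq[ShuffleStats],
      advisoryTargetSize: Long,
      minNumPartitions: Int): Seq[PartitionRange] = {
    val numPartitions = metrics.head.bytesByPartitionId.length
    val sizePerPartition =
      (0 until numPartitions).map(i => metrics.map(_.bytesByPartitionId(i)).sum)
    val totalSize = sizePerPartition.sum
    // Shrink the target if the advisory size would yield fewer than minNumPartitions.
    val targetSize =
      math.min(advisoryTargetSize, math.max(1L, totalSize / math.max(1, minNumPartitions)))
    val ranges = scala.collection.mutable.ArrayBuffer.empty[PartitionRange]
    var start = 0
    var acc = 0L
    for (i <- 0 until numPartitions) {
      if (i > start && acc + sizePerPartition(i) > targetSize) {
        ranges += PartitionRange(start, i)
        start = i
        acc = 0L
      }
      acc += sizePerPartition(i)
    }
    if (start < numPartitions) ranges += PartitionRange(start, numPartitions)
    ranges.toSeq
  }
}

In the actual rule the target size and the minimum partition count come from spark.sql.adaptive.advisoryPartitionSizeInBytes and spark.sql.adaptive.coalescePartitions.minPartitionNum (Spark 3.0 names), the latter falling back to the session's default parallelism when unset, as the in-line comment in the diff notes.
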
@@ -211,24 +211,26 @@ class AdaptiveQueryExecSuite
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
checkAnswer(testDf, Seq())
assert(testDf.rdd.collectPartitions().length == 0)
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined)
val coalescedReaders = collect(plan) {
case r: CustomShuffleReaderExec => r
}
assert(coalescedReaders.length == 2)
assert(coalescedReaders.length == 3, s"$plan")
coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
}

withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") {
val testDf = df1.where('a > 10).join(df2.where('b > 10), "id").groupBy('a).count()
checkAnswer(testDf, Seq())
assert(testDf.rdd.collectPartitions().length == 0)
val plan = testDf.queryExecution.executedPlan
assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
val coalescedReaders = collect(plan) {
case r: CustomShuffleReaderExec => r
}
assert(coalescedReaders.length == 2, s"$plan")
assert(coalescedReaders.length == 3, s"$plan")
coalescedReaders.foreach(r => assert(r.partitionSpecs.isEmpty))
}
}
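
The updated assertions expect three CustomShuffleReaderExec nodes, each with empty partitionSpecs, when every input is filtered down to nothing. A hedged way to hit the same path outside the test harness is sketched below; it assumes a local SparkSession, and EmptyShuffleExample plus the column names are made up for illustration rather than taken from the suite.

import org.apache.spark.sql.SparkSession

object EmptyShuffleExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")          // enable AQE
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")  // force a sort-merge join
      .getOrCreate()
    import spark.implicits._

    val df1 = spark.range(10).withColumn("a", $"id")
    val df2 = spark.range(10).withColumn("b", $"id")

    // Both filters match nothing, so every shuffle feeding the join and the
    // aggregate has a 0-partition input RDD -- the case updatePlan(Nil) covers.
    val result = df1.where($"a" > 10).join(df2.where($"b" > 10), "id").groupBy($"a").count()

    result.collect()  // empty result; the adaptive plan still wraps the shuffles in readers
    println(result.queryExecution.executedPlan)

    spark.stop()
  }
}

Printing result.queryExecution.executedPlan after the collect should show the final adaptive plan with the shuffle readers in place, mirroring what the assertions above check.
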