diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c4d5a2b9b3f0a..5500941936442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -703,7 +703,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
-          ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
+          ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
+            planLater(child), canChangeNumPartitions = false) :: Nil
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
@@ -736,7 +737,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: logical.Range => execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
-        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil
+        exchange.ShuffleExchangeExec(
+          r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
index d93eb76b9fbc4..78923433eaab9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ReduceNumShufflePartitions.scala
@@ -61,12 +61,18 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       // If not all leaf nodes are query stages, it's not safe to reduce the number of
       // shuffle partitions, because we may break the assumption that all children of a spark plan
       // have same number of output partitions.
+      return plan
+    }
+
+    val shuffleStages = plan.collect {
+      case stage: ShuffleQueryStageExec => stage
+      case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage
+    }
+    // ShuffleExchanges introduced by repartition do not support changing the number of partitions.
+    // We change the number of partitions in the stage only if all the ShuffleExchanges support it.
+    if (!shuffleStages.forall(_.plan.canChangeNumPartitions)) {
       plan
     } else {
-      val shuffleStages = plan.collect {
-        case stage: ShuffleQueryStageExec => stage
-        case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => stage
-      }
       val shuffleMetrics = shuffleStages.map { stage =>
         val metricsFuture = stage.mapOutputStatisticsFuture
         assert(metricsFuture.isCompleted, "ShuffleQueryStageExec should already be ready")
@@ -76,12 +82,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       // `ShuffleQueryStageExec` gives null mapOutputStatistics when the input RDD has 0 partitions,
       // we should skip it when calculating the `partitionStartIndices`.
       val validMetrics = shuffleMetrics.filter(_ != null)
-      // We may get different pre-shuffle partition number if user calls repartition manually.
-      // We don't reduce shuffle partition number in that case.
-      val distinctNumPreShufflePartitions =
-        validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
-
-      if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1) {
+      if (validMetrics.nonEmpty) {
         val partitionStartIndices = estimatePartitionStartIndices(validMetrics.toArray)
         // This transformation adds new nodes, so we must use `transformUp` here.
         plan.transformUp {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 8184baf50b042..079fb006ccb50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -94,7 +94,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
         child match {
           // If child is an exchange, we replace it with a new one having defaultPartitioning.
-          case ShuffleExchangeExec(_, c) => ShuffleExchangeExec(defaultPartitioning, c)
+          case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
           case _ => ShuffleExchangeExec(defaultPartitioning, child)
         }
       }
@@ -191,7 +191,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
+    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
       child.outputPartitioning match {
         case lower: HashPartitioning if upper.semanticEquals(lower) => child
         case _ => operator
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 5d0208f1ecc46..fec05a76b4516 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -43,7 +43,8 @@ import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordCo
  */
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
-    child: SparkPlan) extends Exchange {
+    child: SparkPlan,
+    canChangeNumPartitions: Boolean = true) extends Exchange {
 
   // NOTE: coordinator can be null after serialization/deserialization,
   // e.g. it can be null on the Executor side
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index efd5db1c5b6c4..4b08a4b0d1a0f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1365,7 +1365,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val agg = cp.groupBy('id % 2).agg(count('id))
 
     agg.queryExecution.executedPlan.collectFirst {
-      case ShuffleExchangeExec(_, _: RDDScanExec) =>
+      case ShuffleExchangeExec(_, _: RDDScanExec, _) =>
       case BroadcastExchangeExec(_, _: RDDScanExec) =>
     }.foreach { _ =>
       fail(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
index 20fed07d38726..35c33a7157d38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala
@@ -574,22 +574,17 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
     withSparkSession(test, 4, None)
   }
-  test("Union two datasets with different pre-shuffle partition number") {
+  test("Do not reduce the number of shuffle partition for repartition") {
     val test: SparkSession => Unit = { spark: SparkSession =>
-      val dataset1 = spark.range(3)
-      val dataset2 = spark.range(3)
-
-      val resultDf = dataset1.repartition(2, dataset1.col("id"))
-        .union(dataset2.repartition(3, dataset2.col("id"))).toDF()
+      val ds = spark.range(3)
+      val resultDf = ds.repartition(2, ds.col("id")).toDF()
       checkAnswer(resultDf,
-        Seq((0), (0), (1), (1), (2), (2)).map(i => Row(i)))
+        Seq(0, 1, 2).map(i => Row(i)))
       val finalPlan = resultDf.queryExecution.executedPlan
         .asInstanceOf[AdaptiveSparkPlanExec].executedPlan
-      // As the pre-shuffle partition number are different, we will skip reducing
-      // the shuffle partition numbers.
       assert(finalPlan.collect {
         case p: CoalescedShuffleReaderExec => p
       }.length == 0)
     }
-    withSparkSession(test, 100, None)
+    withSparkSession(test, 200, None)
   }
 }
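
For reviewers who want to see the intended effect end to end, below is a minimal usage sketch, not part of the patch. It assumes a build that already contains this change and uses only public APIs; the demo object name is made up, and the exact adaptive-execution config keys may differ between Spark versions.

// Demo sketch only -- not part of the patch.
import org.apache.spark.sql.SparkSession

object RepartitionPartitionCountDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("repartition-keeps-partition-count")
      // Enable adaptive execution; depending on the Spark version, additional
      // coalescing-related keys may also need to be set (assumption).
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.shuffle.partitions", "200")
      .getOrCreate()

    val ds = spark.range(100)

    // A user-requested repartition is planned with canChangeNumPartitions = false,
    // so ReduceNumShufflePartitions keeps both partitions even though the data is tiny.
    val repartitioned = ds.repartition(2, ds.col("id"))
    assert(repartitioned.rdd.getNumPartitions == 2)

    // A planner-introduced shuffle (for the aggregation) is still eligible for
    // coalescing, so its final partition count is typically far below
    // spark.sql.shuffle.partitions.
    val aggregated = ds.groupBy(ds.col("id") % 2).count()
    println(s"aggregate output partitions: ${aggregated.rdd.getNumPartitions}")

    spark.stop()
  }
}

The guard added to ReduceNumShufflePartitions above is what enforces the first assertion: the rule skips a stage as soon as any of its ShuffleExchangeExec nodes reports canChangeNumPartitions = false.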