diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 28d1ccbf4db36..f836debcbafda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -752,7 +752,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
           ShuffleExchangeExec(RoundRobinPartitioning(numPartitions),
-            planLater(child), canChangeNumPartitions = false) :: Nil
+            planLater(child), noUserSpecifiedNumPartition = false) :: Nil
         } else {
           execution.CoalesceExec(numPartitions, planLater(child)) :: Nil
         }
@@ -786,7 +786,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RangeExec(r) :: Nil
       case r: logical.RepartitionByExpression =>
         exchange.ShuffleExchangeExec(
-          r.partitioning, planLater(r.child), canChangeNumPartitions = false) :: Nil
+          r.partitioning, planLater(r.child), noUserSpecifiedNumPartition = false) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
index 6684376b12539..31d1f34b64a65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala
@@ -141,8 +141,10 @@ object OptimizeLocalShuffleReader {
   }
 
   def canUseLocalShuffleReader(plan: SparkPlan): Boolean = plan match {
-    case s: ShuffleQueryStageExec => s.shuffle.canChangeNumPartitions
-    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) => s.shuffle.canChangeNumPartitions
+    case s: ShuffleQueryStageExec =>
+      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
+    case CustomShuffleReaderExec(s: ShuffleQueryStageExec, _, _) =>
+      s.shuffle.canChangeNumPartitions && s.mapStats.isDefined
     case _ => false
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index b7da78cb0eefb..24c736951fdc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -83,7 +83,12 @@ trait ShuffleExchangeLike extends Exchange {
 case class ShuffleExchangeExec(
     override val outputPartitioning: Partitioning,
     child: SparkPlan,
-    canChangeNumPartitions: Boolean = true) extends ShuffleExchangeLike {
+    noUserSpecifiedNumPartition: Boolean = true) extends ShuffleExchangeLike {
+
+  // If users specify the num partitions via APIs like `repartition`, we shouldn't change it.
+  // For `SinglePartition`, it requires exactly one partition and we can't change it either.
+  def canChangeNumPartitions: Boolean =
+    noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
 
   private lazy val writeMetrics =
     SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
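
A minimal, self-contained Scala sketch of the rule this patch encodes (the `Partitioning`, `ShuffleExchange`, and `Demo` names below are simplified stand-ins introduced here for illustration, not Spark's actual classes): adaptive execution may change a shuffle's partition count only when the user did not pin it (e.g. via `repartition` or `RepartitionByExpression`) and the output partitioning is not `SinglePartition`, which by definition requires exactly one partition.

```scala
// Toy stand-ins for Spark's partitioning hierarchy.
sealed trait Partitioning
case object SinglePartition extends Partitioning
case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning

// Simplified stand-in for ShuffleExchangeExec: canChangeNumPartitions is
// derived from the flag and the partitioning, mirroring the patched code.
case class ShuffleExchange(
    outputPartitioning: Partitioning,
    noUserSpecifiedNumPartition: Boolean = true) {
  def canChangeNumPartitions: Boolean =
    noUserSpecifiedNumPartition && outputPartitioning != SinglePartition
}

object Demo extends App {
  // Planner-inserted shuffle: AQE may coalesce its partitions.
  assert(ShuffleExchange(RoundRobinPartitioning(10)).canChangeNumPartitions)
  // User called repartition(10): the partition count is pinned.
  assert(!ShuffleExchange(RoundRobinPartitioning(10),
    noUserSpecifiedNumPartition = false).canChangeNumPartitions)
  // SinglePartition requires exactly one partition, so it can never change.
  assert(!ShuffleExchange(SinglePartition).canChangeNumPartitions)
}
```

Deriving `canChangeNumPartitions` from these two conditions, rather than storing it as a constructor flag, excludes `SinglePartition` exchanges automatically instead of relying on every call site to pass the flag correctly. The `OptimizeLocalShuffleReader` hunk additionally guards on `mapStats.isDefined`, so a local shuffle reader is only attempted when map output statistics are actually available.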