diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 91fb77574a0ca..8c111aa750809 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1017,7 +1017,16 @@ case class RepartitionByExpression(
     child: LogicalPlan,
     optNumPartitions: Option[Int]) extends RepartitionOperation {
 
-  val numPartitions = optNumPartitions.getOrElse(SQLConf.get.numShufflePartitions)
+  val numPartitions = if (optNumPartitions.nonEmpty) {
+    optNumPartitions.get
+  } else {
+    if (partitionExpressions.forall(_.foldable)) {
+      1
+    } else {
+      SQLConf.get.numShufflePartitions
+    }
+  }
+  require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
   val partitioning: Partitioning = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ebfe8bdd7a749..112b1a7210cb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
-import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.catalyst.plans.logical.{Project, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -3732,6 +3732,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       checkAnswer(sql("SELECT s LIKE 'm@@ca' ESCAPE '@' FROM df"), Row(true))
     }
   }
+
+  test("limit partition num to 1 when distributing by foldable expressions") {
+    withSQLConf((SQLConf.SHUFFLE_PARTITIONS.key, "5")) {
+      Seq(1, "1, 2", null, "version()").foreach { expr =>
+        val plan = sql(s"select * from values (1), (2), (3) t(a) distribute by $expr")
+          .queryExecution.optimizedPlan
+        val res = plan.collect {
+          case r: RepartitionByExpression if r.numPartitions == 1 => true
+        }
+        assert(res.nonEmpty)
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
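
For reviewers, a small standalone sketch of the behaviour the new test asserts: with the patch applied, a DISTRIBUTE BY over only foldable expressions resolves RepartitionByExpression.numPartitions to 1 instead of spark.sql.shuffle.partitions. This sketch is not part of the patch; the object name, local master and config values are illustrative assumptions.

// Illustrative sketch only (hypothetical object name and settings), not part of the patch.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.RepartitionByExpression

object FoldableDistributeByDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.sql.shuffle.partitions", "5")
      .getOrCreate()

    // DISTRIBUTE BY a foldable expression: every row hashes to the same value,
    // so requesting spark.sql.shuffle.partitions (5) partitions would leave 4 empty.
    val plan = spark.sql("SELECT * FROM VALUES (1), (2), (3) t(a) DISTRIBUTE BY 1")
      .queryExecution.optimizedPlan

    // With the patch applied this prints "numPartitions = 1"; without it, 5.
    plan.collect { case r: RepartitionByExpression => r.numPartitions }
      .foreach(n => println(s"numPartitions = $n"))

    spark.stop()
  }
}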