diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b92f34680f66..0b9c46936b24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -97,8 +97,10 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     PlanSubqueries(sparkSession),
-    RemoveRedundantSorts(sparkSession.sessionState.conf),
     EnsureRequirements(sparkSession.sessionState.conf),
+    // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
+    // number of partitions when instantiating PartitioningCollection.
+    RemoveRedundantSorts(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 7646f9613efb..28addf602511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -91,7 +91,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def longMetric(name: String): SQLMetric = metrics(name)
 
   // TODO: Move to `DistributedPlan`
-  /** Specifies how data is partitioned across different nodes in the cluster. */
+  /**
+   * Specifies how data is partitioned across different nodes in the cluster.
+   * Note this method may fail if it is invoked before `EnsureRequirements` is applied
+   * since `PartitioningCollection` requires all its partitionings to have
+   * the same number of partitions.
+   */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index f7987e293b3f..b82e5cb77c07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -99,4 +101,29 @@ class RemoveRedundantSortsSuite
       }
     }
   }
+
+  test("SPARK-33472: shuffled join with different left and right side partition numbers") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("t1", "t2") {
+        spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1")
+        (0 to 100).toDF("key").createOrReplaceTempView("t2")
+
+        val query = """
+          |SELECT t1.key
+          |FROM t1 JOIN t2 ON t1.key = t2.key
+          |WHERE t1.key > 10 AND t2.key < 50
+          |ORDER BY t1.key ASC
+        """.stripMargin
+
+        val df = sql(query)
+        val sparkPlan = df.queryExecution.sparkPlan
+        val join = sparkPlan.collect { case j: SortMergeJoinExec => j }.head
+        val leftPartitioning = join.left.outputPartitioning
+        assert(leftPartitioning.isInstanceOf[RangePartitioning])
+        assert(leftPartitioning.numPartitions == 2)
+        assert(join.right.outputPartitioning == UnknownPartitioning(0))
+        checkSorts(query, 3, 3)
+      }
+    }
+  }
 }
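
Why the ordering matters, in short: `RemoveRedundantSorts` checks whether a sort's child already satisfies the required ordering and distribution, which involves calling `child.outputPartitioning`. For a sort-merge join that partitioning is a `PartitioningCollection` built from both sides, and `PartitioningCollection` rejects mismatched partition counts at construction time; only after `EnsureRequirements` has inserted the shuffles are the two sides guaranteed to agree. Below is a rough sketch of that constraint, not code from this patch, with illustrative grouping keys and partition counts:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}

    // Before EnsureRequirements runs, the two sides of a shuffled join can
    // report different partition counts (here 2 vs. 200 for illustration).
    val left = HashPartitioning(Seq(Literal(1)), numPartitions = 2)
    val right = HashPartitioning(Seq(Literal(1)), numPartitions = 200)

    // PartitioningCollection requires all of its partitionings to share the
    // same numPartitions, so constructing it here fails with an
    // IllegalArgumentException, the kind of failure SPARK-33472 describes.
    PartitioningCollection(Seq(left, right))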