diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
index 3cbe1654ea2c..23a9527a1b34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AliasAwareOutputExpression.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference, Expression, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
 
 /**
  * A trait that provides functionality to handle aliases in the `outputExpressions`.
@@ -44,7 +44,7 @@ trait AliasAwareOutputExpression extends UnaryExecNode {
  */
 trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
   final override def outputPartitioning: Partitioning = {
-    if (hasAlias) {
+    val normalizedOutputPartitioning = if (hasAlias) {
       child.outputPartitioning match {
         case e: Expression =>
           normalizeExpression(e).asInstanceOf[Partitioning]
@@ -53,6 +53,24 @@ trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
     } else {
       child.outputPartitioning
     }
+
+    flattenPartitioning(normalizedOutputPartitioning).filter {
+      case hashPartitioning: HashPartitioning => hashPartitioning.references.subsetOf(outputSet)
+      case _ => true
+    } match {
+      case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
+      case Seq(singlePartitioning) => singlePartitioning
+      case seqWithMultiplePartitionings => PartitioningCollection(seqWithMultiplePartitionings)
+    }
+  }
+
+  private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
+    partitioning match {
+      case PartitioningCollection(childPartitionings) =>
+        childPartitionings.flatMap(flattenPartitioning)
+      case rest =>
+        rest +: Nil
+    }
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 4e01d1c06f64..924776ae3ae6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -921,10 +921,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
       val projects = planned.collect { case p: ProjectExec => p }
       assert(projects.exists(_.outputPartitioning match {
-        case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
-          HashPartitioning(Seq(k2: AttributeReference), _))) if k1.name == "t1id" =>
+        case HashPartitioning(Seq(k1: AttributeReference), _) if k1.name == "t1id" =>
           true
-        case _ => false
+        case _ =>
+          false
       }))
     }
   }
@@ -1008,17 +1008,11 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
       val projects = planned.collect { case p: ProjectExec => p }
       assert(projects.exists(_.outputPartitioning match {
-        case PartitioningCollection(Seq(HashPartitioning(Seq(Multiply(ar1, _, _)), _),
-          HashPartitioning(Seq(Multiply(ar2, _, _)), _))) =>
-          Seq(ar1, ar2) match {
-            case Seq(ar1: AttributeReference, ar2: AttributeReference) =>
-              ar1.name == "t1id" && ar2.name == "id2"
-            case _ =>
-              false
-          }
-        case _ => false
+        case HashPartitioning(Seq(Multiply(ar1: AttributeReference, _, _)), _) =>
+          ar1.name == "t1id"
+        case _ =>
+          false
       }))
-    }
   }
 }
@@ -1234,6 +1228,40 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     val numPartitions = range.rdd.getNumPartitions
     assert(numPartitions == 0)
   }
+
+  test("SPARK-33758: Prune unnecessary output partitioning") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      withTempView("t1", "t2") {
+        spark.range(10).repartition($"id").createTempView("t1")
+        spark.range(20).repartition($"id").createTempView("t2")
+        val planned = sql(
+          """
+            | SELECT t1.id as t1id, t2.id as t2id
+            | FROM t1, t2
+            | WHERE t1.id = t2.id
+          """.stripMargin).queryExecution.executedPlan
+
+        assert(planned.outputPartitioning match {
+          case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
+            HashPartitioning(Seq(k2: AttributeReference), _))) =>
+            k1.name == "t1id" && k2.name == "t2id"
+        })
+
+        val planned2 = sql(
+          """
+            | SELECT t1.id as t1id
+            | FROM t1, t2
+            | WHERE t1.id = t2.id
+          """.stripMargin).queryExecution.executedPlan
+        assert(planned2.outputPartitioning match {
+          case HashPartitioning(Seq(k1: AttributeReference), _) if k1.name == "t1id" =>
+            true
+        })
+      }
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
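
The new test above is the authoritative end-to-end check. For reviewers who want to observe the pruning interactively, here is a minimal spark-shell sketch of the same scenario; it is an illustration, not part of the patch, and the local-master session, view names, and row counts are assumptions mirroring the test:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Disable broadcast joins and AQE so the join's hash partitioning is what
// surfaces in the executed plan, matching the test configuration above.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.adaptive.enabled", "false")

spark.range(10).repartition($"id").createOrReplaceTempView("t1")
spark.range(20).repartition($"id").createOrReplaceTempView("t2")

// Only t1.id survives the projection, so with this patch the plan should
// report a single HashPartitioning over t1id instead of a
// PartitioningCollection that still references the pruned t2.id.
val plan = spark.sql(
  "SELECT t1.id AS t1id FROM t1, t2 WHERE t1.id = t2.id"
).queryExecution.executedPlan
println(plan.outputPartitioning)
```

The filter keeps a `HashPartitioning` only when its references are a subset of the node's `outputSet`; if nothing survives, the plan degrades to `UnknownPartitioning` rather than advertising a partitioning over attributes that no longer exist, which downstream consumers such as `EnsureRequirements` could otherwise act on.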