@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, AttributeReference, Expression, NamedExpression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
 
 /**
  * A trait that provides functionality to handle aliases in the `outputExpressions`.
@@ -44,7 +44,7 @@ trait AliasAwareOutputExpression extends UnaryExecNode {
  */
 trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
   final override def outputPartitioning: Partitioning = {
-    if (hasAlias) {
+    val normalizedOutputPartitioning = if (hasAlias) {
       child.outputPartitioning match {
         case e: Expression =>
           normalizeExpression(e).asInstanceOf[Partitioning]
@@ -53,6 +53,24 @@ trait AliasAwareOutputPartitioning extends AliasAwareOutputExpression {
     } else {
       child.outputPartitioning
     }
+
+    flattenPartitioning(normalizedOutputPartitioning).filter {
+      case hashPartitioning: HashPartitioning => hashPartitioning.references.subsetOf(outputSet)
+      case _ => true
+    } match {
+      case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
+      case Seq(singlePartitioning) => singlePartitioning
+      case seqWithMultiplePartitionings => PartitioningCollection(seqWithMultiplePartitionings)
+    }
   }
+
+  private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
+    partitioning match {
+      case PartitioningCollection(childPartitionings) =>
+        childPartitionings.flatMap(flattenPartitioning)
+      case rest =>
+        rest +: Nil
+    }
+  }
 }
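
For context on the change above: after alias normalization, outputPartitioning now flattens any nested PartitioningCollection and keeps only the HashPartitionings whose referenced attributes all survive the projection, falling back to UnknownPartitioning when none do. A minimal, self-contained Scala sketch of that flatten-then-prune shape (Part, HashPart, PartCollection, and UnknownPart are toy stand-ins for illustration, not Spark's classes):

sealed trait Part
case class HashPart(refs: Set[String]) extends Part
case class PartCollection(parts: Seq[Part]) extends Part
case class UnknownPart(numPartitions: Int) extends Part

object PrunePartitioningSketch {
  // Flatten nested collections into one flat sequence, mirroring
  // flattenPartitioning in the diff above.
  def flatten(p: Part): Seq[Part] = p match {
    case PartCollection(children) => children.flatMap(flatten)
    case rest => Seq(rest)
  }

  // Drop hash partitionings whose references do not survive the
  // projection, mirroring the filter over outputSet in the diff.
  def prune(p: Part, outputSet: Set[String], numPartitions: Int): Part =
    flatten(p).filter {
      case HashPart(refs) => refs.subsetOf(outputSet)
      case _ => true
    } match {
      case Seq() => UnknownPart(numPartitions)
      case Seq(single) => single
      case multiple => PartCollection(multiple)
    }

  def main(args: Array[String]): Unit = {
    // A join typically reports hash partitionings from both sides.
    val joined = PartCollection(Seq(HashPart(Set("t1id")), HashPart(Set("t2id"))))
    println(prune(joined, Set("t1id"), 10))  // HashPart(Set(t1id))
    println(prune(joined, Set("none"), 10))  // UnknownPart(10)
  }
}

Flattening first means a collection that nests further collections still collapses to the smallest accurate description of the surviving partitionings.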

@@ -921,10 +921,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
       val projects = planned.collect { case p: ProjectExec => p }
       assert(projects.exists(_.outputPartitioning match {
-        case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
-          HashPartitioning(Seq(k2: AttributeReference), _))) if k1.name == "t1id" =>
+        case HashPartitioning(Seq(k1: AttributeReference), _) if k1.name == "t1id" =>
          true
-        case _ => false
+        case _ =>
+          false
       }))
     }
   }
@@ -1008,17 +1008,11 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
 
         val projects = planned.collect { case p: ProjectExec => p }
         assert(projects.exists(_.outputPartitioning match {
-          case PartitioningCollection(Seq(HashPartitioning(Seq(Multiply(ar1, _, _)), _),
-            HashPartitioning(Seq(Multiply(ar2, _, _)), _))) =>
-            Seq(ar1, ar2) match {
-              case Seq(ar1: AttributeReference, ar2: AttributeReference) =>
-                ar1.name == "t1id" && ar2.name == "id2"
-              case _ =>
-                false
-            }
-          case _ => false
+          case HashPartitioning(Seq(Multiply(ar1: AttributeReference, _, _)), _) =>
+            ar1.name == "t1id"
+          case _ =>
+            false
         }))
-
       }
     }
   }
@@ -1234,6 +1228,40 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
     val numPartitions = range.rdd.getNumPartitions
     assert(numPartitions == 0)
   }
+
+  test("SPARK-33758: Prune unnecessary output partitioning") {
+    withSQLConf(
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+      withTempView("t1", "t2") {
+        spark.range(10).repartition($"id").createTempView("t1")
+        spark.range(20).repartition($"id").createTempView("t2")
+        val planned = sql(
+          """
+            | SELECT t1.id as t1id, t2.id as t2id
+            | FROM t1, t2
+            | WHERE t1.id = t2.id
+          """.stripMargin).queryExecution.executedPlan
+
+        assert(planned.outputPartitioning match {
+          case PartitioningCollection(Seq(HashPartitioning(Seq(k1: AttributeReference), _),
+            HashPartitioning(Seq(k2: AttributeReference), _))) =>
+            k1.name == "t1id" && k2.name == "t2id"
+        })
+
+        val planned2 = sql(
+          """
+            | SELECT t1.id as t1id
+            | FROM t1, t2
+            | WHERE t1.id = t2.id
+          """.stripMargin).queryExecution.executedPlan
+        assert(planned2.outputPartitioning match {
+          case HashPartitioning(Seq(k1: AttributeReference), _) if k1.name == "t1id" =>
+            true
+        })
+      }
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
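
For anyone reproducing this outside the suite, a rough spark-shell sketch along the lines of the new test (assumes the SparkSession bound to spark by spark-shell; the configs and view names mirror the test above):

import spark.implicits._

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.adaptive.enabled", "false")

spark.range(10).repartition($"id").createOrReplaceTempView("t1")
spark.range(20).repartition($"id").createOrReplaceTempView("t2")

// Both join keys survive the projection, so the plan should report a
// PartitioningCollection over the two HashPartitionings.
spark.sql("SELECT t1.id AS t1id, t2.id AS t2id FROM t1, t2 WHERE t1.id = t2.id")
  .queryExecution.executedPlan.outputPartitioning

// Only t1id survives, so the t2-side HashPartitioning should be pruned.
spark.sql("SELECT t1.id AS t1id FROM t1, t2 WHERE t1.id = t2.id")
  .queryExecution.executedPlan.outputPartitioning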