diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 0b4dd76c7d860..997cf92449c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ @@ -169,8 +169,8 @@ case class InMemoryTableScanExec( // But the cached version could alias output, so we need to replace output. override def outputPartitioning: Partitioning = { relation.cachedPlan.outputPartitioning match { - case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning] - case _ => relation.cachedPlan.outputPartitioning + case e: Expression => updateAttribute(e).asInstanceOf[Partitioning] + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala index 09f79a2de0ba0..1a5b7599bb7d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala @@ -24,7 +24,7 @@ import org.apache.spark.broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode} import org.apache.spark.sql.internal.SQLConf @@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan } override def outputPartitioning: Partitioning = child.outputPartitioning match { - case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr)) + case e: Expression => updateAttr(e).asInstanceOf[Partitioning] case other => other } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index ed0ff1be476c7..7c7adf0d362c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{execution, Row} +import org.apache.spark.sql.{execution, DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.columnar.InMemoryRelation +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ @@ -686,6 +686,66 @@ class PlannerSuite extends SharedSQLContext { Range(1, 2, 1, 1))) df.queryExecution.executedPlan.execute() } + + test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " + + "and InMemoryTableScanExec") { + def checkOutputPartitioningRewrite( + plans: Seq[SparkPlan], + expectedPartitioningClass: Class[_]): Unit = { + assert(plans.size == 1) + val plan = plans.head + val partitioning = plan.outputPartitioning + assert(partitioning.getClass == expectedPartitioningClass) + val partitionedAttrs = partitioning.asInstanceOf[Expression].references + assert(partitionedAttrs.subsetOf(plan.outputSet)) + } + + def checkReusedExchangeOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val reusedExchange = df.queryExecution.executedPlan.collect { + case r: ReusedExchangeExec => r + } + checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass) + } + + def checkInMemoryTableScanOutputPartitioningRewrite( + df: DataFrame, + expectedPartitioningClass: Class[_]): Unit = { + val inMemoryScan = df.queryExecution.executedPlan.collect { + case m: InMemoryTableScanExec => m + } + checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass) + } + + // ReusedExchange is HashPartitioning + val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i") + checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning]) + + // ReusedExchange is RangePartitioning + val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i") + checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning]) + + // InMemoryTableScan is HashPartitioning + Seq(1 -> "a").toDF("i", "j").repartition($"i").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning]) + + // InMemoryTableScan is RangePartitioning + spark.range(1, 100, 1, 10).toDF().persist() + checkInMemoryTableScanOutputPartitioningRewrite( + spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning]) + + // InMemoryTableScan is PartitioningCollection + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist() + checkInMemoryTableScanOutputPartitioningRewrite( + Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"), + classOf[PartitioningCollection]) + } + } } // Used for unit-testing EnsureRequirements