@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.vectorized._
 import org.apache.spark.sql.types._
@@ -169,8 +169,8 @@ case class InMemoryTableScanExec(
   // But the cached version could alias output, so we need to replace output.
   override def outputPartitioning: Partitioning = {
     relation.cachedPlan.outputPartitioning match {
-      case h: HashPartitioning => updateAttribute(h).asInstanceOf[HashPartitioning]
-      case _ => relation.cachedPlan.outputPartitioning
+      case e: Expression => updateAttribute(e).asInstanceOf[Partitioning]
+      case other => other
     }
   }
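Why matching on Expression is enough (a hedged reading of the change, not wording from the patch itself): HashPartitioning, RangePartitioning, and PartitioningCollection all mix in Expression, so the single Expression case lets updateAttribute remap the cached plan's possibly-aliased attribute references onto the scan's own output for every expression-based partitioning, while partitionings that carry no attribute references fall through unchanged. A minimal standalone sketch of that kind of rewrite; the helper name rewritePartitioning and the positional cachedOutput-to-scanOutput mapping are illustrative assumptions, not code from this PR:

    // Illustrative sketch only, not the method from this patch.
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
    import org.apache.spark.sql.catalyst.plans.physical.Partitioning

    def rewritePartitioning(
        cachedOutput: Seq[Attribute],
        scanOutput: Seq[Attribute],
        partitioning: Partitioning): Partitioning = {
      // Assumes the scan's output aliases the cached plan's output positionally.
      val attrMap = AttributeMap(cachedOutput.zip(scanOutput))
      partitioning match {
        // HashPartitioning, RangePartitioning and PartitioningCollection are all
        // Expressions, so one case rewrites every partitioning that references attributes.
        case e: Expression =>
          e.transform { case a: Attribute => attrMap.getOrElse(a, a) }
            .asInstanceOf[Partitioning]
        // Partitionings without attribute references (e.g. UnknownPartitioning) pass through.
        case other => other
      }
    }

Fed the cached plan's output, the scan's output, and the cached plan's outputPartitioning, such a rewrite yields a partitioning whose references all belong to the scan's outputSet, which is exactly what the new test below asserts.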

@@ -24,7 +24,7 @@ import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression, SortOrder}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.internal.SQLConf

@@ -70,7 +70,7 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
   }

   override def outputPartitioning: Partitioning = child.outputPartitioning match {
-    case h: HashPartitioning => h.copy(expressions = h.expressions.map(updateAttr))
+    case e: Expression => updateAttr(e).asInstanceOf[Partitioning]
     case other => other
   }

Member commented: LGTM
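To see what the old HashPartitioning-only match missed, consider a reused exchange whose child reports RangePartitioning: the sort expressions keep the attribute ids of the originally planned exchange, while ReusedExchangeExec exposes re-aliased output attributes. A hedged, illustrative snippet (the attributes i and iNew are hypothetical stand-ins, not names from this PR) showing how the generalized Expression case rewrites such stale references so that the partitioning's references land inside the node's own outputSet, which the new test below verifies with subsetOf:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeMap, SortOrder}
    import org.apache.spark.sql.catalyst.plans.physical.RangePartitioning

    // Two attributes with the same name but different exprIds model the aliasing:
    // `i` is what the original exchange's partitioning refers to, `iNew` is the
    // attribute actually exposed by ReusedExchangeExec.
    val i = 'i.int
    val iNew = 'i.int
    val attrMap = AttributeMap(Seq(i -> iNew))

    val stale = RangePartitioning(Seq(SortOrder(i, Ascending)), numPartitions = 10)
    // Before this change the non-hash partitioning was returned untouched; with the
    // Expression case the reference to `i` is remapped to `iNew`.
    val rewritten = stale
      .transform { case a: Attribute => attrMap.getOrElse(a, a) }
      .asInstanceOf[RangePartitioning]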

@@ -18,13 +18,13 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql.{execution, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._

@@ -686,6 +686,66 @@ class PlannerSuite extends SharedSQLContext {
         Range(1, 2, 1, 1)))
     df.queryExecution.executedPlan.execute()
   }

test("SPARK-24556: always rewrite output partitioning in ReusedExchangeExec " +
"and InMemoryTableScanExec") {
def checkOutputPartitioningRewrite(
plans: Seq[SparkPlan],

Contributor commented:

Now we can take a single SparkPlan.

Contributor Author (@yucai, Jun 19, 2018) replied:

What do you think about merging the check*OutputPartitioningRewrite helpers together?

    def checkPlanAndOutputPartitioningRewrite(
        df: DataFrame,
        expectedPlanClass: Class[_],
        expectedPartitioningClass: Class[_]): Unit = {
      val plans = df.queryExecution.executedPlan.collect {
        case r: ReusedExchangeExec => r
        case m: InMemoryTableScanExec => m
      }
      assert(plans.size == 1)
      val plan = plans.head
      assert(plan.getClass == expectedPlanClass)
      val partitioning = plan.outputPartitioning
      assert(partitioning.getClass == expectedPartitioningClass)
      val partitionedAttrs = partitioning.asInstanceOf[Expression].references
      assert(partitionedAttrs.subsetOf(plan.outputSet))
    }

Contributor Author replied:

@cloud-fan I still use Seq, so I can make checkReusedExchangeOutputPartitioningRewrite and checkInMemoryTableScanOutputPartitioningRewrite simpler. Kindly let me know if you have a better idea.

+        expectedPartitioningClass: Class[_]): Unit = {
+      assert(plans.size == 1)
+      val plan = plans.head
+      val partitioning = plan.outputPartitioning
+      assert(partitioning.getClass == expectedPartitioningClass)
+      val partitionedAttrs = partitioning.asInstanceOf[Expression].references
+      assert(partitionedAttrs.subsetOf(plan.outputSet))
+    }
+
+    def checkReusedExchangeOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val reusedExchange = df.queryExecution.executedPlan.collect {
+        case r: ReusedExchangeExec => r
+      }
+      checkOutputPartitioningRewrite(reusedExchange, expectedPartitioningClass)
+    }
+
+    def checkInMemoryTableScanOutputPartitioningRewrite(
+        df: DataFrame,
+        expectedPartitioningClass: Class[_]): Unit = {
+      val inMemoryScan = df.queryExecution.executedPlan.collect {
+        case m: InMemoryTableScanExec => m
+      }
+      checkOutputPartitioningRewrite(inMemoryScan, expectedPartitioningClass)
+    }
+
+    // ReusedExchange is HashPartitioning
+    val df1 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    val df2 = Seq(1 -> "a").toDF("i", "j").repartition($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df1.union(df2), classOf[HashPartitioning])
+
+    // ReusedExchange is RangePartitioning
+    val df3 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    val df4 = Seq(1 -> "a").toDF("i", "j").orderBy($"i")
+    checkReusedExchangeOutputPartitioningRewrite(df3.union(df4), classOf[RangePartitioning])
+
+    // InMemoryTableScan is HashPartitioning
+    Seq(1 -> "a").toDF("i", "j").repartition($"i").persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      Seq(1 -> "a").toDF("i", "j").repartition($"i"), classOf[HashPartitioning])
+
+    // InMemoryTableScan is RangePartitioning
+    spark.range(1, 100, 1, 10).toDF().persist()
+    checkInMemoryTableScanOutputPartitioningRewrite(
+      spark.range(1, 100, 1, 10).toDF(), classOf[RangePartitioning])
+
+    // InMemoryTableScan is PartitioningCollection
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m").persist()
+      checkInMemoryTableScanOutputPartitioningRewrite(
+        Seq(1 -> "a").toDF("i", "j").join(Seq(1 -> "a").toDF("m", "n"), $"i" === $"m"),
+        classOf[PartitioningCollection])
+    }
+  }
 }

 // Used for unit-testing EnsureRequirements