sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   private def reorder(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
+      leftKeys: IndexedSeq[Expression],
+      rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
-    val pickedIndexes = mutable.Set[Int]()
-    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
-
-    expectedOrderOfKeys.foreach(expression => {
-      val index = keysAndIndexes.find { case (e, idx) =>
-        // As we may have the same key used many times, we need to filter out its occurrence we
-        // have already used.
-        e.semanticEquals(expression) && !pickedIndexes.contains(idx)
-      }.map(_._2).get
-      pickedIndexes += index
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
-    })
+    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
+      return (leftKeys, rightKeys)
+    }
+
+    // Build a lookup between an expression and the positions it holds in the current key seq.

[Inline review comment from the PR author on this change: "The previous implementation had the potential for quadratic behavior in quite a few places so I changed all that. I might have gotten carried away here, especially given the fact that number of keys is often quite low." A standalone sketch of the new approach follows this hunk.]

+    val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
+    currentOrderOfKeys.zipWithIndex.foreach {
+      case (key, index) =>
+        keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
+    }
+
+    // Reorder the keys.
+    val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
+    val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
+    val iterator = expectedOrderOfKeys.iterator
+    while (iterator.hasNext) {
+      // Lookup the current index of this key.
+      keyToIndexMap.get(iterator.next().canonicalized) match {
+        case Some(indices) if indices.nonEmpty =>
+          // Take the first available index from the map.
+          val index = indices.firstKey
+          indices.remove(index)
+
+          // Add the keys for that index to the reordered keys.
+          leftKeysBuffer += leftKeys(index)
+          rightKeysBuffer += rightKeys(index)
+        case _ =>
+          // The expression cannot be found, or we have exhausted all indices for that expression.
+          return (leftKeys, rightKeys)
+      }
+    }
     (leftKeysBuffer, rightKeysBuffer)
   }
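
To make the lookup-map idea easy to poke at outside Spark, here is a minimal standalone sketch of the same algorithm (not part of the PR). Plain String keys stand in for canonicalized Catalyst expressions, so the `canonicalized` calls and `semanticEquals` matching are dropped; `ReorderSketch` and its `main` driver are illustrative names only.

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object ReorderSketch {
  def reorder[K](
      leftKeys: IndexedSeq[K],
      rightKeys: IndexedSeq[K],
      expectedOrderOfKeys: Seq[K],
      currentOrderOfKeys: Seq[K]): (Seq[K], Seq[K]) = {
    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
      return (leftKeys, rightKeys)
    }

    // One pass to map each key to the set of positions it occupies; a duplicate
    // key simply contributes more positions instead of forcing a linear rescan.
    val keyToIndexMap = mutable.Map.empty[K, mutable.BitSet]
    currentOrderOfKeys.zipWithIndex.foreach { case (key, index) =>
      keyToIndexMap.getOrElseUpdate(key, mutable.BitSet.empty).add(index)
    }

    val leftKeysBuffer = new ArrayBuffer[K](leftKeys.size)
    val rightKeysBuffer = new ArrayBuffer[K](rightKeys.size)
    val iterator = expectedOrderOfKeys.iterator
    while (iterator.hasNext) {
      keyToIndexMap.get(iterator.next()) match {
        case Some(indices) if indices.nonEmpty =>
          // Consume the smallest position of this key that has not been used yet.
          val index = indices.head
          indices.remove(index)
          leftKeysBuffer += leftKeys(index)
          rightKeysBuffer += rightKeys(index)
        case _ =>
          // Unknown key, or all its occurrences are used up: keep the original order.
          return (leftKeys, rightKeys)
      }
    }
    (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // "a" appears twice in the current order; each occurrence is consumed once.
    val (l, r) = reorder(
      leftKeys = IndexedSeq("a", "b", "a"),
      rightKeys = IndexedSeq("x", "y", "z"),
      expectedOrderOfKeys = Seq("b", "a", "a"),
      currentOrderOfKeys = Seq("a", "b", "a"))
    println(l.mkString(", ")) // b, a, a
    println(r.mkString(", ")) // y, x, z
  }
}

Building the BitSet map is a single O(n) pass, and each expected key then claims its smallest unused position in effectively constant time, which is what removes the quadratic find-and-filter over keysAndIndexes in the old implementation.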

@@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
+      (leftPartitioning, rightPartitioning) match {
+        case (HashPartitioning(leftExpressions, _), _) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+        case (_, HashPartitioning(rightExpressions, _)) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+        case _ =>
+          (leftKeys, rightKeys)
       }
     } else {
       (leftKeys, rightKeys)
sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala (20 additions, 0 deletions)
@@ -897,6 +897,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements should not fail join with duplicate keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val tbl_a = spark.range(40)
+        .select($"id" as "x", $"id" % 10 as "y")
+        .repartition(2, $"x", $"y", $"x")
+        .as("tbl_a")
+
+      val tbl_b = spark.range(20)
+        .select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2")
+        .as("tbl_b")
+
+      val res = tbl_a
+        .join(tbl_b,
+          $"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && $"tbl_a.y" === $"tbl_b.y2")
+        .select($"tbl_a.x")
+      checkAnswer(res, Row(0L) :: Row(1L) :: Nil)
+    }
+  }
+
   test("SPARK-26352: join reordering should not change the order of columns") {
     withTable("tab1", "tab2", "tab3") {
       spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -696,6 +696,32 @@ class PlannerSuite extends SharedSQLContext {
     }
  }
 
+  test("SPARK-27485: EnsureRequirements.reorder should handle duplicate expressions") {
+    val plan1 = DummySparkPlan(
+      outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan()
+    val smjExec = SortMergeJoinExec(
+      leftKeys = exprA :: exprB :: exprB :: Nil,
+      rightKeys = exprA :: exprC :: exprC :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = plan1,
+      right = plan2)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+        SortExec(_, _,
+          ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _), _),
+        SortExec(_, _,
+          ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _), _)) =>
+        assert(leftKeys === smjExec.leftKeys)
+        assert(rightKeys === smjExec.rightKeys)
+        assert(leftKeys === leftPartitioningExpressions)
+        assert(rightKeys === rightPartitioningExpressions)
+      case _ => fail(outputPlan.toString)
+    }
+  }
+
   test("SPARK-24500: create union with stream of children") {
     val df = Union(Stream(
       Range(1, 1, 1, 1),