diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 8184baf50b042..e4a8bd6df8065 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
-  SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -117,25 +116,41 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   private def reorder(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
+      leftKeys: IndexedSeq[Expression],
+      rightKeys: IndexedSeq[Expression],
       expectedOrderOfKeys: Seq[Expression],
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-    val leftKeysBuffer = ArrayBuffer[Expression]()
-    val rightKeysBuffer = ArrayBuffer[Expression]()
-    val pickedIndexes = mutable.Set[Int]()
-    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
-
-    expectedOrderOfKeys.foreach(expression => {
-      val index = keysAndIndexes.find { case (e, idx) =>
-        // As we may have the same key used many times, we need to filter out its occurrence we
-        // have already used.
-        e.semanticEquals(expression) && !pickedIndexes.contains(idx)
-      }.map(_._2).get
-      pickedIndexes += index
-      leftKeysBuffer.append(leftKeys(index))
-      rightKeysBuffer.append(rightKeys(index))
-    })
+    if (expectedOrderOfKeys.size != currentOrderOfKeys.size) {
+      return (leftKeys, rightKeys)
+    }
+
+    // Build a lookup between an expression and the positions its holds in the current key seq.
+    val keyToIndexMap = mutable.Map.empty[Expression, mutable.BitSet]
+    currentOrderOfKeys.zipWithIndex.foreach {
+      case (key, index) =>
+        keyToIndexMap.getOrElseUpdate(key.canonicalized, mutable.BitSet.empty).add(index)
+    }
+
+    // Reorder the keys.
+    val leftKeysBuffer = new ArrayBuffer[Expression](leftKeys.size)
+    val rightKeysBuffer = new ArrayBuffer[Expression](rightKeys.size)
+    val iterator = expectedOrderOfKeys.iterator
+    while (iterator.hasNext) {
+      // Lookup the current index of this key.
+      keyToIndexMap.get(iterator.next().canonicalized) match {
+        case Some(indices) if indices.nonEmpty =>
+          // Take the first available index from the map.
+          val index = indices.firstKey
+          indices.remove(index)
+
+          // Add the keys for that index to the reordered keys.
+          leftKeysBuffer += leftKeys(index)
+          rightKeysBuffer += rightKeys(index)
+        case _ =>
+          // The expression cannot be found, or we have exhausted all indices for that expression.
+          return (leftKeys, rightKeys)
+      }
+    }
     (leftKeysBuffer, rightKeysBuffer)
   }
 
@@ -145,20 +160,13 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       leftPartitioning: Partitioning,
       rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
+      (leftPartitioning, rightPartitioning) match {
+        case (HashPartitioning(leftExpressions, _), _) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys)
+        case (_, HashPartitioning(rightExpressions, _)) =>
+          reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
+        case _ =>
+          (leftKeys, rightKeys)
       }
     } else {
       (leftKeys, rightKeys)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 531cc8660b6ef..5212209b56f69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -897,6 +897,26 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements should not fail join with duplicate keys") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val tbl_a = spark.range(40)
+        .select($"id" as "x", $"id" % 10 as "y")
+        .repartition(2, $"x", $"y", $"x")
+        .as("tbl_a")
+
+      val tbl_b = spark.range(20)
+        .select($"id" as "x", $"id" % 2 as "y1", $"id" % 20 as "y2")
+        .as("tbl_b")
+
+      val res = tbl_a
+        .join(tbl_b,
+          $"tbl_a.x" === $"tbl_b.x" && $"tbl_a.y" === $"tbl_b.y1" && $"tbl_a.y" === $"tbl_b.y2")
+        .select($"tbl_a.x")
+      checkAnswer(res, Row(0L) :: Row(1L) :: Nil)
+    }
+  }
+
   test("SPARK-26352: join reordering should not change the order of columns") {
     withTable("tab1", "tab2", "tab3") {
       spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c2d9e54981928..f4dba4ac6d986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -696,6 +696,32 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("SPARK-27485: EnsureRequirements.reorder should handle duplicate expressions") {
+    val plan1 = DummySparkPlan(
+      outputPartitioning = HashPartitioning(exprA :: exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan()
+    val smjExec = SortMergeJoinExec(
+      leftKeys = exprA :: exprB :: exprB :: Nil,
+      rightKeys = exprA :: exprC :: exprC :: Nil,
+      joinType = Inner,
+      condition = None,
+      left = plan1,
+      right = plan2)
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _,
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(leftPartitioningExpressions, _), _), _),
+             SortExec(_, _,
+               ShuffleExchangeExec(HashPartitioning(rightPartitioningExpressions, _), _), _)) =>
+        assert(leftKeys === smjExec.leftKeys)
+        assert(rightKeys === smjExec.rightKeys)
+        assert(leftKeys === leftPartitioningExpressions)
+        assert(rightKeys === rightPartitioningExpressions)
+      case _ => fail(outputPlan.toString)
+    }
+  }
+
   test("SPARK-24500: create union with stream of children") {
     val df = Union(Stream(
       Range(1, 1, 1, 1),
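
For readers who want to try the reordering strategy outside of Spark, the following is a minimal standalone sketch of the same idea, using plain Strings in place of Catalyst Expressions; ReorderSketch and its member names are illustrative only and are not part of the patch. It records every position a key occupies in the current order (so duplicates each keep their own slot), consumes one position per expected key, and falls back to the original key order whenever a key cannot be matched, which is the behaviour the PlannerSuite test above asserts.

import scala.collection.mutable

// Standalone sketch of the index-tracking idea from the patch (assumption: plain Strings
// stand in for Catalyst Expressions; none of these names are Spark APIs).
object ReorderSketch {

  def reorder[K](
      leftKeys: IndexedSeq[K],
      rightKeys: IndexedSeq[K],
      expectedOrder: Seq[K],
      currentOrder: Seq[K]): (Seq[K], Seq[K]) = {
    if (expectedOrder.size != currentOrder.size) {
      return (leftKeys, rightKeys)
    }

    // Record every position a key occupies in the current order, so a duplicated key keeps
    // one slot per occurrence instead of always resolving to the same index.
    val keyToIndexes = mutable.Map.empty[K, mutable.BitSet]
    currentOrder.zipWithIndex.foreach { case (key, idx) =>
      keyToIndexes.getOrElseUpdate(key, mutable.BitSet.empty).add(idx)
    }

    val leftBuffer = mutable.ArrayBuffer.empty[K]
    val rightBuffer = mutable.ArrayBuffer.empty[K]
    val it = expectedOrder.iterator
    while (it.hasNext) {
      keyToIndexes.get(it.next()) match {
        case Some(indexes) if indexes.nonEmpty =>
          val idx = indexes.head // smallest position not consumed yet
          indexes.remove(idx)
          leftBuffer += leftKeys(idx)
          rightBuffer += rightKeys(idx)
        case _ =>
          // The key is missing, or all of its positions are used up: keep the original order.
          return (leftKeys, rightKeys)
      }
    }
    (leftBuffer.toSeq, rightBuffer.toSeq)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the PlannerSuite case: join keys (a, b, b) against a partitioning of (a, b, a).
    // The second "a" in the expected order has no index left, so the keys come back unchanged.
    println(reorder(IndexedSeq("a", "b", "b"), IndexedSeq("a", "c", "c"),
      expectedOrder = Seq("a", "b", "a"), currentOrder = Seq("a", "b", "b")))

    // With no exhausted duplicates the keys are reordered to follow the expected order.
    println(reorder(IndexedSeq("x", "y"), IndexedSeq("u", "v"),
      expectedOrder = Seq("y", "x"), currentOrder = Seq("x", "y")))
  }
}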