From 89065f6cc2c41ae5d7d330dd035902e1bf28b876 Mon Sep 17 00:00:00 2001
From: Peter Toth
Date: Wed, 28 Sep 2022 15:51:38 +0200
Subject: [PATCH] [SPARK-42136][SQL] Refactor BroadcastHashJoinExec output
 partitioning generation

---
 .../joins/BroadcastHashJoinExec.scala        | 27 ++++---------------
 .../execution/joins/BroadcastJoinSuite.scala |  6 ++---
 2 files changed, 8 insertions(+), 25 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 69c760b5a00b1..08eaacca2f49f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -112,28 +112,11 @@ case class BroadcastHashJoinExec(
   // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
   // The expanded expressions are returned as PartitioningCollection.
   private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
-    val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
-    var currentNumCombinations = 0
-
-    def generateExprCombinations(
-        current: Seq[Expression],
-        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
-      if (currentNumCombinations >= maxNumCombinations) {
-        Nil
-      } else if (current.isEmpty) {
-        currentNumCombinations += 1
-        Seq(accumulated)
-      } else {
-        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
-        generateExprCombinations(current.tail, accumulated :+ current.head) ++
-          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
-            .getOrElse(Nil)
-      }
-    }
-
-    PartitioningCollection(
-      generateExprCombinations(partitioning.expressions, Nil)
-        .map(HashPartitioning(_, partitioning.numPartitions)))
+    PartitioningCollection(partitioning.multiTransformDown {
+      case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+    }.asInstanceOf[Stream[HashPartitioning]]
+      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6333808b42086..47714c669d5ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -554,8 +554,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       right = DummySparkPlan())
     var expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2, l3), 1),
-      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, l3), 1),
+      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, r2), 1)))
     assert(bhj.outputPartitioning === expected)
 
@@ -571,8 +571,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       right = DummySparkPlan())
     expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1),
       HashPartitioning(Seq(l3), 1),
       HashPartitioning(Seq(r3), 1)))
@@ -623,8 +623,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     val expected = Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1))
 
     Seq(1, 2, 3, 4).foreach { limit =>
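
Note on the expansion this patch refactors: each streamed-side partitioning key
that has equivalent build-side keys (per streamedKeyToBuildKeyMapping, collected
from the join condition) can be substituted by any of them, producing up to 2^n
equivalent HashPartitionings, capped by
spark.sql.execution.broadcastHashJoinOutputPartitioningExpandLimit. Below is a
minimal standalone sketch of that expansion on plain Strings instead of Catalyst
Expressions; ExpandSketch, expand, and the mapping shape are illustrative names,
not Spark API, and Stream mirrors the laziness the patch relies on.

object ExpandSketch {
  def expand(
      keys: Seq[String],
      mapping: Map[String, Seq[String]],
      limit: Int): Seq[Seq[String]] = {
    // Lazily enumerate all combinations: for each key, either keep the
    // streamed key itself or substitute one of its equivalent build keys.
    def combinations(current: Seq[String], accumulated: Seq[String]): Stream[Seq[String]] =
      current match {
        case Seq() => Stream(accumulated)
        case head +: tail =>
          val alternatives = head +: mapping.getOrElse(head, Nil)
          alternatives.toStream.flatMap(a => combinations(tail, accumulated :+ a))
      }
    // Laziness makes the cap cheap: combinations past the limit are never built.
    combinations(keys, Nil).take(limit).toList
  }

  def main(args: Array[String]): Unit = {
    // The example from the code comment above: "b" is equivalent to build
    // key "x" and "c" to "y", so Seq("a", "b", "c") expands to four variants.
    expand(Seq("a", "b", "c"), Map("b" -> Seq("x"), "c" -> Seq("y")), limit = 4)
      .foreach(println)
    // List(a, b, c), List(a, b, y), List(a, x, c), List(a, x, y)
  }
}

The reordered test expectations capture the one observable difference:
multiTransformDown enumerates alternatives in a different order than the removed
recursive generateExprCombinations, so the set of generated partitionings is
unchanged while their order inside the PartitioningCollection shifts.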