diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 69c760b5a00b1..08eaacca2f49f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -112,28 +112,11 @@ case class BroadcastHashJoinExec(
   // Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
   // The expanded expressions are returned as PartitioningCollection.
   private def expandOutputPartitioning(partitioning: HashPartitioning): PartitioningCollection = {
-    val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
-    var currentNumCombinations = 0
-
-    def generateExprCombinations(
-        current: Seq[Expression],
-        accumulated: Seq[Expression]): Seq[Seq[Expression]] = {
-      if (currentNumCombinations >= maxNumCombinations) {
-        Nil
-      } else if (current.isEmpty) {
-        currentNumCombinations += 1
-        Seq(accumulated)
-      } else {
-        val buildKeysOpt = streamedKeyToBuildKeyMapping.get(current.head.canonicalized)
-        generateExprCombinations(current.tail, accumulated :+ current.head) ++
-          buildKeysOpt.map(_.flatMap(b => generateExprCombinations(current.tail, accumulated :+ b)))
-            .getOrElse(Nil)
-      }
-    }
-
-    PartitioningCollection(
-      generateExprCombinations(partitioning.expressions, Nil)
-        .map(HashPartitioning(_, partitioning.numPartitions)))
+    PartitioningCollection(partitioning.multiTransformDown {
+      case e: Expression if streamedKeyToBuildKeyMapping.contains(e.canonicalized) =>
+        e +: streamedKeyToBuildKeyMapping(e.canonicalized)
+    }.asInstanceOf[Stream[HashPartitioning]]
+      .take(conf.broadcastHashJoinOutputPartitioningExpandLimit))
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6333808b42086..47714c669d5ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -554,8 +554,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       right = DummySparkPlan())
     var expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2, l3), 1),
-      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, l3), 1),
+      HashPartitioning(Seq(l1, l2, r2), 1),
       HashPartitioning(Seq(l1, r1, r2), 1)))
     assert(bhj.outputPartitioning === expected)
 
@@ -571,8 +571,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
       right = DummySparkPlan())
     expected = PartitioningCollection(Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1),
       HashPartitioning(Seq(l3), 1),
       HashPartitioning(Seq(r3), 1)))
 
@@ -623,8 +623,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
 
     val expected = Seq(
       HashPartitioning(Seq(l1, l2), 1),
-      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, l2), 1),
+      HashPartitioning(Seq(l1, r2), 1),
       HashPartitioning(Seq(r1, r2), 1))
 
     Seq(1, 2, 3, 4).foreach { limit =>
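
Note for reviewers: the following is a minimal standalone sketch, not Spark code, of the expansion that multiTransformDown performs here. The object and method names are hypothetical, and plain Scala collections stand in for Catalyst expressions; it assumes only that each mapped streamed key expands into (key +: buildKeys) and that the cross product is enumerated lazily, so that .take(limit) stops generating combinations early. It also models the enumeration order (alternatives for earlier keys cycle before later ones) that the reordered expectations in BroadcastJoinSuite above reflect.

// Illustrative sketch only (hypothetical names, plain Scala collections),
// not Spark's API: each key that has build-side equivalents expands into
// (key +: buildKeys), and the cross product is built as a LazyList so that
// .take(limit) forces only as many combinations as requested.
object ExpandPartitioningSketch {
  def expand[A](keys: Seq[A], mapping: Map[A, Seq[A]]): LazyList[Seq[A]] =
    keys.foldRight(LazyList(Seq.empty[A])) { (key, tails) =>
      val alternatives = key +: mapping.getOrElse(key, Nil)
      for {
        tail <- tails                       // later keys are fixed first,
        alt  <- alternatives.to(LazyList)   // so earlier keys vary fastest
      } yield alt +: tail
    }

  def main(args: Array[String]): Unit = {
    // Assume l2 maps to the build-side key r1, and l3 maps to r2.
    val mapping = Map("l2" -> Seq("r1"), "l3" -> Seq("r2"))
    // Prints the combinations in the order the updated test expects:
    // List(l1, l2, l3), List(l1, r1, l3), List(l1, l2, r2), List(l1, r1, r2)
    expand(Seq("l1", "l2", "l3"), mapping).take(4).foreach(println)
  }
}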