diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6e7bcb8825488..097ea61f13832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -251,7 +251,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + condition.isEmpty ) private[this] val joinRow = new JoinedRow @@ -330,7 +331,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + condition.isEmpty ) private[this] val joinRow = new JoinedRow @@ -653,6 +655,7 @@ case class SortMergeJoinExec( * internal buffer * @param spillThreshold Threshold for number of rows to be spilled by internal buffer * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found + * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the first matching row */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -662,7 +665,8 @@ private[joins] class SortMergeJoinScanner( bufferedIter: RowIterator, inMemoryThreshold: Int, spillThreshold: Int, - eagerCleanupResources: () => Unit) { + eagerCleanupResources: () => Unit, + onlyBufferFirstMatch: Boolean = false) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -673,8 +677,9 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold, + spillThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -834,7 +839,9 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + if (!onlyBufferFirstMatch || bufferedMatches.isEmpty) { + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + } advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 93cd84713296b..942cf24a3a873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -749,6 +749,14 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan ) } + // LEFT SEMI JOIN without bound condition does not spill + assertNotSpilled(sparkContext, "left semi join") { + checkAnswer( + sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a WHERE key = 2"), + Row(2, "2") :: Nil + ) + } + val expected = new ListBuffer[Row]() expected.append( Row(1, "1", 1, 1), Row(1, "1", 1, 2),