From 1a49356b8b7e021c1c8f7176e70b3d26ce8fc491 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 28 Aug 2020 17:20:06 +0200 Subject: [PATCH 1/6] [SPARK-32730][SQL] Improve LeftSemi SortMergeJoin right side buffering --- .../ExternalAppendOnlyUnsafeRowArray.scala | 10 +++- .../execution/joins/SortMergeJoinExec.scala | 52 +++++++++++++++---- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index ac282ea2e94f5..1510369c265de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -30,6 +30,14 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultIn import org.apache.spark.storage.BlockManager import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} +trait AppendOnlyUnsafeRowArray { + def clear() + def add(row: UnsafeRow) + def isEmpty: Boolean + def length: Int + def generateIterator(): Iterator[UnsafeRow] +} + /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which @@ -50,7 +58,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) extends Logging { + numRowsSpillThreshold: Int) extends AppendOnlyUnsafeRowArray with Logging { def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 6e7bcb8825488..937bb83435507 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -146,7 +146,7 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var currentRightMatches: AppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), @@ -156,7 +156,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + false ) private[this] val joinRow = new JoinedRow @@ -201,7 +202,8 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + false ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( @@ -216,7 +218,8 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + false ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( @@ -251,7 +254,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + condition.isEmpty ) private[this] val joinRow = new JoinedRow @@ -287,7 +291,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + false ) private[this] val joinRow = new JoinedRow @@ -330,7 +335,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources + cleanupResources, + condition.isEmpty ) private[this] val joinRow = new JoinedRow @@ -662,7 +668,8 @@ private[joins] class SortMergeJoinScanner( bufferedIter: RowIterator, inMemoryThreshold: Int, spillThreshold: Int, - eagerCleanupResources: () => Unit) { + eagerCleanupResources: () => Unit, + bufferFirstOnly: Boolean) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -673,8 +680,29 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches = + private[this] val bufferedMatches: AppendOnlyUnsafeRowArray = if (bufferFirstOnly) { + new AppendOnlyUnsafeRowArray { + var buffer: UnsafeRow = null + + override def clear(): Unit = { + buffer = null + } + + override def add(row: UnsafeRow): Unit = { + assert(buffer == null) + + buffer = row + } + + override def isEmpty: Boolean = buffer == null + + override def length: Int = if (buffer == null) 0 else 1 + + override def generateIterator(): Iterator[UnsafeRow] = Iterator(buffer) + } + } else { new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + } // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -683,7 +711,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches + def getBufferedMatches: AppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. If no @@ -834,7 +862,9 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + if (!bufferFirstOnly || bufferedMatches.isEmpty) { + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + } advancedBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } From acc6646e7e83fa9e3f082b1aaa8c6227e3d8a7cf Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 29 Aug 2020 10:30:41 +0200 Subject: [PATCH 2/6] add UT --- .../src/test/scala/org/apache/spark/sql/JoinSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 93cd84713296b..892dcf0d58963 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -749,6 +749,15 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan ) } + // LEFT SEMI JOIN without bound condition does not use [[ExternalAppendOnlyUnsafeRowArray]] + // so should not cause any spill + assertNotSpilled(sparkContext, "left semi join") { + checkAnswer( + sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a WHERE key = 2"), + Row(2, "2") :: Nil + ) + } + val expected = new ListBuffer[Row]() expected.append( Row(1, "1", 1, 1), Row(1, "1", 1, 2), From 037b876bb1726d0f8a7eb6b574a0daedeb5d53e4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 2 Sep 2020 21:07:19 +0200 Subject: [PATCH 3/6] drop AppendOnlyUnsafeRowArray trait --- .../ExternalAppendOnlyUnsafeRowArray.scala | 10 +------ .../execution/joins/SortMergeJoinExec.scala | 30 ++++--------------- .../org/apache/spark/sql/JoinSuite.scala | 3 +- 3 files changed, 7 insertions(+), 36 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index 1510369c265de..ac282ea2e94f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -30,14 +30,6 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultIn import org.apache.spark.storage.BlockManager import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} -trait AppendOnlyUnsafeRowArray { - def clear() - def add(row: UnsafeRow) - def isEmpty: Boolean - def length: Int - def generateIterator(): Iterator[UnsafeRow] -} - /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which @@ -58,7 +50,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) extends AppendOnlyUnsafeRowArray with Logging { + numRowsSpillThreshold: Int) extends Logging { def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 937bb83435507..84f09d4cb00cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -146,7 +146,7 @@ case class SortMergeJoinExec( case _: InnerLike => new RowIterator { private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: AppendOnlyUnsafeRowArray = _ + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = new SortMergeJoinScanner( createLeftKeyGenerator(), @@ -680,29 +680,9 @@ private[joins] class SortMergeJoinScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ - private[this] val bufferedMatches: AppendOnlyUnsafeRowArray = if (bufferFirstOnly) { - new AppendOnlyUnsafeRowArray { - var buffer: UnsafeRow = null - - override def clear(): Unit = { - buffer = null - } - - override def add(row: UnsafeRow): Unit = { - assert(buffer == null) - - buffer = row - } - - override def isEmpty: Boolean = buffer == null - - override def length: Int = if (buffer == null) 0 else 1 - - override def generateIterator(): Iterator[UnsafeRow] = Iterator(buffer) - } - } else { - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) - } + private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(if (bufferFirstOnly) 1 else inMemoryThreshold, + spillThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() @@ -711,7 +691,7 @@ private[joins] class SortMergeJoinScanner( def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: AppendOnlyUnsafeRowArray = bufferedMatches + def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. If no diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 892dcf0d58963..942cf24a3a873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -749,8 +749,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan ) } - // LEFT SEMI JOIN without bound condition does not use [[ExternalAppendOnlyUnsafeRowArray]] - // so should not cause any spill + // LEFT SEMI JOIN without bound condition does not spill assertNotSpilled(sparkContext, "left semi join") { checkAnswer( sql("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a WHERE key = 2"), From 3937e4ca0b8631d785543c53d492eaa5b17b1ab4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 3 Sep 2020 09:10:13 +0200 Subject: [PATCH 4/6] address review comments --- .../execution/joins/SortMergeJoinExec.scala | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 84f09d4cb00cf..e33d243b82ee8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -156,9 +156,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources, - false - ) + cleanupResources) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { @@ -202,9 +200,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources, - false - ) + cleanupResources) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala @@ -218,9 +214,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, - cleanupResources, - false - ) + cleanupResources) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala @@ -291,9 +285,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources, - false - ) + cleanupResources) private[this] val joinRow = new JoinedRow override def advanceNext(): Boolean = { @@ -659,6 +651,7 @@ case class SortMergeJoinExec( * internal buffer * @param spillThreshold Threshold for number of rows to be spilled by internal buffer * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found + * @param matchedBufferFirstOnly [[bufferMatchingRows]] should buffer only the first matching row */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -669,7 +662,7 @@ private[joins] class SortMergeJoinScanner( inMemoryThreshold: Int, spillThreshold: Int, eagerCleanupResources: () => Unit, - bufferFirstOnly: Boolean) { + matchedBufferFirstOnly: Boolean = false) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -681,7 +674,7 @@ private[joins] class SortMergeJoinScanner( private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(if (bufferFirstOnly) 1 else inMemoryThreshold, + new ExternalAppendOnlyUnsafeRowArray(if (matchedBufferFirstOnly) 1 else inMemoryThreshold, spillThreshold) // Initialization (note: do _not_ want to advance streamed here). @@ -842,7 +835,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - if (!bufferFirstOnly || bufferedMatches.isEmpty) { + if (!matchedBufferFirstOnly || bufferedMatches.isEmpty) { bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) } advancedBufferedToRowWithNullFreeJoinKey() From 5cf3ab329d77b4cc525b2f4ca1992b75782a7f05 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 3 Sep 2020 09:39:34 +0200 Subject: [PATCH 5/6] rename param --- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index e33d243b82ee8..5017210d7845a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -651,7 +651,7 @@ case class SortMergeJoinExec( * internal buffer * @param spillThreshold Threshold for number of rows to be spilled by internal buffer * @param eagerCleanupResources the eager cleanup function to be invoked when no join row found - * @param matchedBufferFirstOnly [[bufferMatchingRows]] should buffer only the first matching row + * @param onlyBufferFirstMatch [[bufferMatchingRows]] should buffer only the first matching row */ private[joins] class SortMergeJoinScanner( streamedKeyGenerator: Projection, @@ -662,7 +662,7 @@ private[joins] class SortMergeJoinScanner( inMemoryThreshold: Int, spillThreshold: Int, eagerCleanupResources: () => Unit, - matchedBufferFirstOnly: Boolean = false) { + onlyBufferFirstMatch: Boolean = false) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -674,7 +674,7 @@ private[joins] class SortMergeJoinScanner( private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(if (matchedBufferFirstOnly) 1 else inMemoryThreshold, + new ExternalAppendOnlyUnsafeRowArray(if (onlyBufferFirstMatch) 1 else inMemoryThreshold, spillThreshold) // Initialization (note: do _not_ want to advance streamed here). @@ -835,7 +835,7 @@ private[joins] class SortMergeJoinScanner( matchJoinKey = streamedRowKey.copy() bufferedMatches.clear() do { - if (!matchedBufferFirstOnly || bufferedMatches.isEmpty) { + if (!onlyBufferFirstMatch || bufferedMatches.isEmpty) { bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) } advancedBufferedToRowWithNullFreeJoinKey() From f699118df05e25193a81c9bedce5b4eb10023079 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 3 Sep 2020 09:52:20 +0200 Subject: [PATCH 6/6] revert accidental changes --- .../sql/execution/joins/SortMergeJoinExec.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 5017210d7845a..097ea61f13832 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -156,7 +156,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources) + cleanupResources + ) private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { @@ -200,7 +201,8 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources) + cleanupResources + ) val rightNullRow = new GenericInternalRow(right.output.length) new LeftOuterIterator( smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala @@ -214,7 +216,8 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, - cleanupResources) + cleanupResources + ) val leftNullRow = new GenericInternalRow(left.output.length) new RightOuterIterator( smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala @@ -285,7 +288,8 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, - cleanupResources) + cleanupResources + ) private[this] val joinRow = new JoinedRow override def advanceNext(): Boolean = {