diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2c57956de5bc..17f487947aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.joins -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -266,7 +264,10 @@ case class SortMergeJoinExec( rightIter = RowIterator.fromScala(rightIter), boundCondition, leftNullRow, - rightNullRow) + rightNullRow, + inMemoryThreshold, + spillThreshold, + cleanupResources) new FullOuterIterator( smjScanner, @@ -998,7 +999,10 @@ private class SortMergeFullOuterJoinScanner( rightIter: RowIterator, boundCondition: InternalRow => Boolean, leftNullRow: InternalRow, - rightNullRow: InternalRow) { + rightNullRow: InternalRow, + inMemoryThreshold: Int, + spillThreshold: Int, + eagerCleanupResources: () => Unit) { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftRow: InternalRow = _ private[this] var leftRowKey: InternalRow = _ @@ -1007,8 +1011,14 @@ private class SortMergeFullOuterJoinScanner( private[this] var leftIndex: Int = 0 private[this] var rightIndex: Int = 0 - private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] - private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow] + private[this] var leftMatchesIterator: Iterator[UnsafeRow] = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = _ + private[this] var leftCurrentRow: InternalRow = _ + private[this] var rightCurrentRow: InternalRow = _ + private[this] val leftMatches: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + private[this] val rightMatches: ExternalAppendOnlyUnsafeRowArray = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) private[this] var leftMatched: BitSet = new BitSet(1) private[this] var rightMatched: BitSet = new BitSet(1) @@ -1060,23 +1070,32 @@ private class SortMergeFullOuterJoinScanner( rightIndex = 0 while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) { - leftMatches += leftRow.copy() + leftMatches.add(leftRow.copy().asInstanceOf[UnsafeRow]) advancedLeft() } while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) { - rightMatches += rightRow.copy() + rightMatches.add(rightRow.copy().asInstanceOf[UnsafeRow]) advancedRight() } - if (leftMatches.size <= leftMatched.capacity) { - leftMatched.clearUntil(leftMatches.size) + if (leftMatches.length > 0) { + leftMatchesIterator = leftMatches.generateIterator() + leftCurrentRow = leftMatchesIterator.next() + } + if (rightMatches.length > 0) { + rightMatchesIterator = rightMatches.generateIterator() + rightCurrentRow = rightMatchesIterator.next() + } + + if (leftMatches.length <= leftMatched.capacity) { + leftMatched.clearUntil(leftMatches.length) } else { - leftMatched = new BitSet(leftMatches.size) + leftMatched = new BitSet(leftMatches.length) } - if (rightMatches.size <= rightMatched.capacity) { - rightMatched.clearUntil(rightMatches.size) + if (rightMatches.length <= rightMatched.capacity) { + rightMatched.clearUntil(rightMatches.length) } else { - rightMatched = new BitSet(rightMatches.size) + rightMatched = new BitSet(rightMatches.length) } } @@ -1090,48 +1109,70 @@ private class SortMergeFullOuterJoinScanner( * @return true if a valid match is found, false otherwise. */ private def scanNextInBuffered(): Boolean = { - while (leftIndex < leftMatches.size) { - while (rightIndex < rightMatches.size) { - joinedRow(leftMatches(leftIndex), rightMatches(rightIndex)) + while (leftIndex < leftMatches.length) { + while (rightIndex < rightMatches.length) { + joinedRow(leftCurrentRow, rightCurrentRow) if (boundCondition(joinedRow)) { leftMatched.set(leftIndex) rightMatched.set(rightIndex) rightIndex += 1 + nextRightRow() return true } rightIndex += 1 + nextRightRow() } rightIndex = 0 + if (rightMatches.length > 0) { + rightMatchesIterator = rightMatches.generateIterator() + rightCurrentRow = rightMatchesIterator.next() + } if (!leftMatched.get(leftIndex)) { // the left row has never matched any right row, join it with null row - joinedRow(leftMatches(leftIndex), rightNullRow) + joinedRow(leftCurrentRow, rightNullRow) leftIndex += 1 + nextLeftRow() return true } leftIndex += 1 + nextLeftRow() } - while (rightIndex < rightMatches.size) { + while (rightIndex < rightMatches.length) { if (!rightMatched.get(rightIndex)) { // the right row has never matched any left row, join it with null row - joinedRow(leftNullRow, rightMatches(rightIndex)) + joinedRow(leftNullRow, rightCurrentRow) rightIndex += 1 + nextRightRow() return true } rightIndex += 1 + nextRightRow() } // There are no more valid matches in the left and right buffers false } + private def nextLeftRow(): Unit = { + if (leftIndex < leftMatches.length) { + leftCurrentRow = leftMatchesIterator.next() + } + } + + private def nextRightRow(): Unit = { + if (rightIndex < rightMatches.length) { + rightCurrentRow = rightMatchesIterator.next() + } + } + // --- Public methods -------------------------------------------------------------------------- def getJoinedRow(): JoinedRow = joinedRow def advanceNext(): Boolean = { // If we already buffered some matching rows, use them directly - if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) { + if (leftIndex <= leftMatches.length || rightIndex <= rightMatches.length) { if (scanNextInBuffered()) { return true } @@ -1158,6 +1199,7 @@ private class SortMergeFullOuterJoinScanner( true } else { // Both iterators have been consumed + eagerCleanupResources() false } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index fe6775cc7f9b..2cb8e4b0b400 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -791,9 +791,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan ) } - // FULL OUTER JOIN still does not use [[ExternalAppendOnlyUnsafeRowArray]] - // so should not cause any spill - assertNotSpilled(sparkContext, "full outer join") { + assertSpilled(sparkContext, "full outer join") { checkAnswer( sql( """