diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 639b8e00c121..799b2fed81cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -179,6 +179,7 @@ case class SortMergeJoinExec( currentRightMatches = null currentLeftRow = null rightMatchesIterator = null + smjScanner.destruct() return false } } @@ -188,6 +189,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.destruct() false } @@ -266,6 +268,7 @@ case class SortMergeJoinExec( } } } + smjScanner.destruct() false } @@ -306,6 +309,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.destruct() false } @@ -344,6 +348,7 @@ case class SortMergeJoinExec( numOutputRows += 1 return true } + smjScanner.destruct() false } @@ -604,6 +609,12 @@ case class SortMergeJoinExec( | } | if (shouldStop()) return; |} + |while ($leftInput.hasNext()) { + | $leftInput.next(); + |} + while ($rightInput.hasNext()) { + | $rightInput.next(); + |} """.stripMargin } } @@ -649,6 +660,11 @@ private[joins] class SortMergeJoinScanner( // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() + def destruct(): Unit = { + while (streamedIter.advanceNext()) {} + while (bufferedIter.advanceNext()) {} + } + // --- Public methods --------------------------------------------------------------------------- def getStreamedRow: InternalRow = streamedRow @@ -915,7 +931,11 @@ private abstract class OneSideOuterIterator( override def advanceNext(): Boolean = { val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() - if (r) numOutputRows += 1 + if (r) { + numOutputRows += 1 + } else { + smjScanner.destruct() + } r } @@ -947,6 +967,10 @@ private class SortMergeFullOuterJoinScanner( advancedLeft() advancedRight() + def destruct(): Unit = { + while (leftIter.advanceNext()) {} + while (rightIter.advanceNext()) {} + } // --- Private methods -------------------------------------------------------------------------- /** @@ -1103,7 +1127,11 @@ private class FullOuterIterator( override def advanceNext(): Boolean = { val r = smjScanner.advanceNext() - if (r) numRows += 1 + if (r) { + numRows += 1 + } else { + smjScanner.destruct(); + } r }