diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java new file mode 100644 index 000000000000..14fc42d2170c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowIterator.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator; +import scala.collection.AbstractIterator; + +import java.io.Closeable; +import java.io.IOException; + +public abstract class UnsafeExternalRowIterator extends AbstractIterator implements Closeable { + + private final UnsafeSorterIterator sortedIterator; + private UnsafeRow row; + + UnsafeExternalRowIterator(StructType schema, UnsafeSorterIterator iterator) { + row = new UnsafeRow(schema.length()); + sortedIterator = iterator; + } + + @Override + public boolean hasNext() { + return sortedIterator.hasNext(); + } + + @Override + public UnsafeRow next() { + try { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + sortedIterator.getRecordLength()); + if (!hasNext()) { + UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page + row = null; // so that we don't keep references to the base object + close(); + return copy; + } else { + return row; + } + } catch (IOException e) { + close(); + // Scala iterators don't declare any checked exceptions, so we need to use this hack + // to re-throw the exception: + Platform.throwException(e); + } + throw new RuntimeException("Exception should have been re-thrown in next()"); + } + + /** + * Implementation should clean up resources used by this iterator, to prevent memory leaks + */ + @Override + public abstract void close(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index d6edddfc1ae6..6d5d1af21a02 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -169,39 +168,13 @@ public Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractIterator() { - - private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(numFields); - - @Override - public boolean hasNext() { - return sortedIterator.hasNext(); - } + return new UnsafeExternalRowIterator(schema, sortedIterator) { @Override - public UnsafeRow next() { - try { - sortedIterator.loadNext(); - row.pointTo( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - sortedIterator.getRecordLength()); - if (!hasNext()) { - UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page - row = null; // so that we don't keep references to the base object - cleanupResources(); - return copy; - } else { - return row; - } - } catch (IOException e) { - cleanupResources(); - // Scala iterators don't declare any checked exceptions, so we need to use this hack - // to re-throw the exception: - Platform.throwException(e); - } - throw new RuntimeException("Exception should have been re-thrown in next()"); + public void close() { + // Caller should clean up resources if the iterator is not consumed in its entirety, + // for example, in a SortMergeJoin. 
+ cleanupResources(); } }; } catch (IOException e) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f829f07e8072..ac5a01e652ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -195,6 +195,7 @@ case class SortMergeJoinExec( currentRightMatches = null currentLeftRow = null rightMatchesIterator = null + smjScanner.close() return false } } @@ -204,6 +205,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.close() false } @@ -285,6 +287,7 @@ case class SortMergeJoinExec( } } } + smjScanner.close() false } @@ -326,6 +329,7 @@ case class SortMergeJoinExec( return true } } + smjScanner.close() false } @@ -365,6 +369,7 @@ case class SortMergeJoinExec( numOutputRows += 1 return true } + smjScanner.close() false } @@ -669,7 +674,7 @@ private[joins] class SortMergeJoinScanner( streamedIter: RowIterator, bufferedIter: RowIterator, inMemoryThreshold: Int, - spillThreshold: Int) { + spillThreshold: Int) extends CloseableScanner { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ private[this] var bufferedRow: InternalRow = _ @@ -785,6 +790,15 @@ private[joins] class SortMergeJoinScanner( } } + /** + * Once the join has been completed, the iterators for both relations + * should be closed, so that acquired memory can be released. 
+ */ + def close(): Unit = { + closeIterator(streamedIter) + closeIterator(bufferedIter) + } + // --- Private methods -------------------------------------------------------------------------- /** @@ -952,7 +966,11 @@ private abstract class OneSideOuterIterator( override def advanceNext(): Boolean = { val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream() - if (r) numOutputRows += 1 + if (r) { + numOutputRows += 1 + } else { + smjScanner.close() + } r } @@ -967,7 +985,7 @@ private class SortMergeFullOuterJoinScanner( rightIter: RowIterator, boundCondition: InternalRow => Boolean, leftNullRow: InternalRow, - rightNullRow: InternalRow) { + rightNullRow: InternalRow) extends CloseableScanner { private[this] val joinedRow: JoinedRow = new JoinedRow() private[this] var leftRow: InternalRow = _ private[this] var leftRowKey: InternalRow = _ @@ -1130,6 +1148,15 @@ private class SortMergeFullOuterJoinScanner( false } } + + /** + * Once the join has been completed, the iterators for both relations + * should be closed, so that acquired memory can be released. + */ + def close(): Unit = { + closeIterator(leftIter) + closeIterator(rightIter) + } } private class FullOuterIterator( @@ -1140,9 +1167,28 @@ private class FullOuterIterator( override def advanceNext(): Boolean = { val r = smjScanner.advanceNext() - if (r) numRows += 1 + if (r) { + numRows += 1 + } else { + smjScanner.close() + } r } override def getRow: InternalRow = resultProj(joinedRow) } + +trait CloseableScanner { + def closeIterator(iter: RowIterator): Unit = { + iter match { + case rowIter: RowIteratorFromScala => + val underlyingIter = rowIter.toScala + underlyingIter match { + case toClose: UnsafeExternalRowIterator => + toClose.close() + case _ => + } + case _ => + } + } +}