diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/execution/RowIterator.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3123f2187da83..90b55a8586de7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.function.Supplier; -import scala.collection.AbstractIterator; import scala.collection.Iterator; import scala.math.Ordering; @@ -168,7 +167,7 @@ public void cleanupResources() { sorter.cleanupResources(); } - public Iterator sort() throws IOException { + public Iterator sort() throws IOException { try { final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator(); if (!sortedIterator.hasNext()) { @@ -176,31 +175,32 @@ public Iterator sort() throws IOException { // here in order to prevent memory leaks. cleanupResources(); } - return new AbstractIterator() { + return new RowIterator() { private final int numFields = schema.length(); private UnsafeRow row = new UnsafeRow(numFields); @Override - public boolean hasNext() { - return !isReleased && sortedIterator.hasNext(); - } - - @Override - public UnsafeRow next() { + public boolean advanceNext() { try { - sortedIterator.loadNext(); - row.pointTo( - sortedIterator.getBaseObject(), - sortedIterator.getBaseOffset(), - sortedIterator.getRecordLength()); - if (!hasNext()) { - UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page - row = null; // so that we don't keep references to the base object - cleanupResources(); - return copy; + if (!isReleased && sortedIterator.hasNext()) { + sortedIterator.loadNext(); + row.pointTo( + sortedIterator.getBaseObject(), + sortedIterator.getBaseOffset(), + sortedIterator.getRecordLength()); + // Here is the initial bug fix in SPARK-9364: the bug fix of use-after-free bug + // when returning the last row from an iterator. For example, in + // [[GroupedIterator]], we still use the last row after traversing the iterator + // in `fetchNextGroupIterator` + if (!sortedIterator.hasNext()) { + row = row.copy(); // so that we don't have dangling pointers to freed page + cleanupResources(); + } + return true; } else { - return row; + row = null; // so that we don't keep references to the base object + return false; } } catch (IOException e) { cleanupResources(); @@ -210,14 +210,18 @@ public UnsafeRow next() { } throw new RuntimeException("Exception should have been re-thrown in next()"); } - }; + + @Override + public UnsafeRow getRow() { return row; } + + }.toScala(); } catch (IOException e) { cleanupResources(); throw e; } } - public Iterator sort(Iterator inputIterator) throws IOException { + public Iterator sort(Iterator inputIterator) throws IOException { while (inputIterator.hasNext()) { insertRow(inputIterator.next()); }