diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index a6e858ca7202..e2059cec132d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.LinkedList; import java.util.Queue; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; @@ -48,8 +49,16 @@ public final class UnsafeExternalSorter extends MemoryConsumer { @Nullable private final PrefixComparator prefixComparator; + + /** + * {@link RecordComparator} may probably keep the reference to the records they compared last + * time, so we should not keep a {@link RecordComparator} instance inside + * {@link UnsafeExternalSorter}, because {@link UnsafeExternalSorter} is referenced by + * {@link TaskContext} and thus can not be garbage collected until the end of the task. + */ @Nullable - private final RecordComparator recordComparator; + private final Supplier recordComparatorSupplier; + private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final SerializerManager serializerManager; @@ -90,14 +99,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparator, prefixComparator, initialSize, + serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. @@ -110,14 +119,14 @@ public static UnsafeExternalSorter create( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, + taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, numElementsForSpillThreshold, null, canUseRadixSort); } @@ -126,7 +135,7 @@ private UnsafeExternalSorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + Supplier recordComparatorSupplier, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, @@ -138,15 +147,24 @@ private UnsafeExternalSorter( this.blockManager = blockManager; this.serializerManager = serializerManager; this.taskContext = taskContext; - this.recordComparator = recordComparator; + this.recordComparatorSupplier = recordComparatorSupplier; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 this.fileBufferSizeBytes = 32 * 1024; if (existingInMemorySorter == null) { + RecordComparator comparator = null; + if (recordComparatorSupplier != null) { + comparator = recordComparatorSupplier.get(); + } this.inMemSorter = new UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); + this, + taskMemoryManager, + comparator, + prefixComparator, + initialSize, + canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } @@ -451,14 +469,14 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(recordComparator != null); + assert(recordComparatorSupplier != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { - final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); + final UnsafeSorterSpillMerger spillMerger = new UnsafeSorterSpillMerger( + recordComparatorSupplier.get(), prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cd5db1a70f72..5330a688e63e 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -154,7 +154,7 @@ private UnsafeExternalSorter newSorter() throws IOException { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, /* initialSize */ 1024, pageSizeBytes, @@ -440,7 +440,7 @@ public void testPeakMemoryUsed() throws Exception { blockManager, serializerManager, taskContext, - recordComparator, + () -> recordComparator, prefixComparator, 1024, pageSizeBytes, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index aadfcaa56cc2..5ea16b737db2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -84,7 +84,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - new RowComparator(ordering, schema.length()), + () -> new RowComparator(ordering, schema.length()), prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -195,12 +195,10 @@ public Iterator sort(Iterator inputIterator) throws IOExce private static final class RowComparator extends RecordComparator { private final Ordering ordering; - private final int numFields; private final UnsafeRow row1; private final UnsafeRow row2; RowComparator(Ordering ordering, int numFields) { - this.numFields = numFields; this.row1 = new UnsafeRow(numFields); this.row2 = new UnsafeRow(numFields); this.ordering = ordering; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index ee5bcfd02c79..376560ca654d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -19,6 +19,7 @@ import javax.annotation.Nullable; import java.io.IOException; +import java.util.function.Supplier; import com.google.common.annotations.VisibleForTesting; @@ -76,7 +77,8 @@ public UnsafeKVExternalSorter( prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema); PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); BaseOrdering ordering = GenerateOrdering.create(keySchema); - KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + Supplier comparatorSupplier = + () -> new KVComparator(ordering, keySchema.length()); boolean canUseRadixSort = keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)); @@ -88,7 +90,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - recordComparator, + comparatorSupplier, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -104,7 +106,11 @@ public UnsafeKVExternalSorter( // as the underlying array for in-memory sorter (it's always large enough). // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(), + null, + taskMemoryManager, + comparatorSupplier.get(), + prefixComparator, + map.getArray(), canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory @@ -137,7 +143,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - new KVComparator(ordering, keySchema.length()), + comparatorSupplier, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -227,10 +233,8 @@ private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; private final UnsafeRow row1; private final UnsafeRow row2; - private final int numKeyFields; KVComparator(BaseOrdering ordering, int numKeyFields) { - this.numKeyFields = numKeyFields; this.row1 = new UnsafeRow(numKeyFields); this.row2 = new UnsafeRow(numKeyFields); this.ordering = ordering;