diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e425879220..a781759e01e8 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -34,4 +34,8 @@ public abstract int compare( long leftBaseOffset, Object rightBaseObject, long rightBaseOffset); + + public interface Factory { + RecordComparator create(); + } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index f312fa2b2ddd..fd11cd859cac 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -49,7 +49,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { @Nullable private final PrefixComparator prefixComparator; @Nullable - private final RecordComparator recordComparator; + private final RecordComparator.Factory recordComparatorFactory; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; private final SerializerManager serializerManager; @@ -91,14 +91,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + RecordComparator.Factory recordComparatorFactory, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparator, prefixComparator, initialSize, + serializerManager, taskContext, recordComparatorFactory, prefixComparator, initialSize, numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. @@ -111,14 +111,14 @@ public static UnsafeExternalSorter create( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + RecordComparator.Factory recordComparatorFactory, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, long numElementsForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, - taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, + taskContext, recordComparatorFactory, prefixComparator, initialSize, pageSizeBytes, numElementsForSpillThreshold, null, canUseRadixSort); } @@ -127,7 +127,7 @@ private UnsafeExternalSorter( BlockManager blockManager, SerializerManager serializerManager, TaskContext taskContext, - RecordComparator recordComparator, + RecordComparator.Factory recordComparatorFactory, PrefixComparator prefixComparator, int initialSize, long pageSizeBytes, @@ -139,7 +139,7 @@ private UnsafeExternalSorter( this.blockManager = blockManager; this.serializerManager = serializerManager; this.taskContext = taskContext; - this.recordComparator = recordComparator; + this.recordComparatorFactory = recordComparatorFactory; this.prefixComparator = prefixComparator; // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024 @@ -151,7 +151,7 @@ private UnsafeExternalSorter( if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( - this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort); + this, taskMemoryManager, recordComparatorFactory, prefixComparator, initialSize, canUseRadixSort); } else { this.inMemSorter = existingInMemorySorter; } @@ -454,14 +454,14 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { - assert(recordComparator != null); + assert(recordComparatorFactory != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); return readingIterator; } else { final UnsafeSorterSpillMerger spillMerger = - new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size()); + new UnsafeSorterSpillMerger(recordComparatorFactory.create(), prefixComparator, spillWriters.size()); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager)); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c14c12664f5a..d14b926a0da3 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -75,8 +75,11 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; + + private final PrefixComparator prefixComparator; + @Nullable - private final Comparator sortComparator; + private final RecordComparator.Factory recordComparatorFactory; /** * If non-null, specifies the radix sort parameters and that radix sort will be used. @@ -84,6 +87,9 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { @Nullable private final PrefixComparators.RadixSortSupport radixSortSupport; + @Nullable + private Comparator sortComparator; + /** * Within this buffer, position {@code 2 * i} holds a pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. @@ -118,26 +124,28 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, - final RecordComparator recordComparator, + final RecordComparator.Factory recordComparatorFactory, final PrefixComparator prefixComparator, int initialSize, boolean canUseRadixSort) { - this(consumer, memoryManager, recordComparator, prefixComparator, + this(consumer, memoryManager, recordComparatorFactory, prefixComparator, consumer.allocateArray(initialSize * 2), canUseRadixSort); } public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, - final RecordComparator recordComparator, + final RecordComparator.Factory recordComparatorFactory, final PrefixComparator prefixComparator, LongArray array, boolean canUseRadixSort) { this.consumer = consumer; this.memoryManager = memoryManager; this.initialSize = array.size(); - if (recordComparator != null) { - this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + this.prefixComparator = prefixComparator; + this.recordComparatorFactory = recordComparatorFactory; + if (recordComparatorFactory != null) { + this.sortComparator = createNewSortComparator(); if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) { this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator; } else { @@ -151,6 +159,10 @@ public UnsafeInMemorySorter( this.usableCapacity = getUsableCapacity(); } + private SortComparator createNewSortComparator() { + return new SortComparator(recordComparatorFactory.create(), prefixComparator, memoryManager); + } + private int getUsableCapacity() { // Radix sort requires same amount of used memory as buffer, Tim sort requires // half of the used memory as buffer. @@ -175,6 +187,9 @@ public void reset() { } pos = 0; nullBoundaryPos = 0; + if (sortComparator != null) { + sortComparator = createNewSortComparator(); + } } /** diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 771d39016c18..8dc48ec1247a 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -65,14 +65,15 @@ public class UnsafeExternalSorterSuite { final PrefixComparator prefixComparator = PrefixComparators.LONG; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final RecordComparator recordComparator = new RecordComparator() { + final RecordComparator.Factory recordComparatorFactory = new RecordComparator.Factory() { @Override - public int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset) { - return 0; + public RecordComparator create() { + return new RecordComparator() { + @Override + public int compare(Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) { + return 0; + } + }; } }; @@ -154,7 +155,7 @@ private UnsafeExternalSorter newSorter() throws IOException { blockManager, serializerManager, taskContext, - recordComparator, + recordComparatorFactory, prefixComparator, /* initialSize */ 1024, pageSizeBytes, @@ -415,7 +416,7 @@ public void testPeakMemoryUsed() throws Exception { blockManager, serializerManager, taskContext, - recordComparator, + recordComparatorFactory, prefixComparator, 1024, pageSizeBytes, diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085aa9a1..acf66de72f4d 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -54,7 +54,7 @@ public void testSortingEmptyInput() { final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, - mock(RecordComparator.class), + mock(RecordComparator.Factory.class), mock(PrefixComparator.class), 100, shouldUseRadixSort()); @@ -92,14 +92,15 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { } // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so // use a dummy comparator - final RecordComparator recordComparator = new RecordComparator() { + final RecordComparator.Factory recordComparator = new RecordComparator.Factory() { @Override - public int compare( - Object leftBaseObject, - long leftBaseOffset, - Object rightBaseObject, - long rightBaseOffset) { - return 0; + public RecordComparator create() { + return new RecordComparator() { + @Override + public int compare(Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) { + return 0; + } + }; } }; // Compute key prefixes based on the records' partition ids diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index c29b002a998c..4363d935d7d6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -84,7 +84,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - new RowComparator(ordering, schema.length()), + new RowComparatorFactory(ordering, schema.length()), prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -193,6 +193,21 @@ public Iterator sort(Iterator inputIterator) throws IOExce return sort(); } + private static final class RowComparatorFactory implements RecordComparator.Factory { + private final Ordering ordering; + private final int numFields; + + private RowComparatorFactory(Ordering ordering, int numFields) { + this.ordering = ordering; + this.numFields = numFields; + } + + @Override + public RecordComparator create() { + return new RowComparator(ordering, numFields); + } + } + private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index ee5bcfd02c79..b06429deb613 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -76,7 +76,7 @@ public UnsafeKVExternalSorter( prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema); PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema); BaseOrdering ordering = GenerateOrdering.create(keySchema); - KVComparator recordComparator = new KVComparator(ordering, keySchema.length()); + KVComparatorFactory recordComparatorFactory = new KVComparatorFactory(ordering, keySchema.length()); boolean canUseRadixSort = keySchema.length() == 1 && SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0)); @@ -88,7 +88,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - recordComparator, + recordComparatorFactory, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -104,7 +104,7 @@ public UnsafeKVExternalSorter( // as the underlying array for in-memory sorter (it's always large enough). // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(), + null, taskMemoryManager, recordComparatorFactory, prefixComparator, map.getArray(), canUseRadixSort); // We cannot use the destructive iterator here because we are reusing the existing memory @@ -137,7 +137,7 @@ public UnsafeKVExternalSorter( blockManager, serializerManager, taskContext, - new KVComparator(ordering, keySchema.length()), + recordComparatorFactory, prefixComparator, SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize", UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -223,6 +223,21 @@ public void cleanupResources() { sorter.cleanupResources(); } + private static final class KVComparatorFactory implements RecordComparator.Factory { + private final BaseOrdering ordering; + private final int numKeyFields; + + private KVComparatorFactory(BaseOrdering ordering, int numKeyFields) { + this.ordering = ordering; + this.numKeyFields = numKeyFields; + } + + @Override + public RecordComparator create() { + return new KVComparator(ordering, numKeyFields); + } + } + private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; private final UnsafeRow row1;