Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ public abstract int compare(
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset);

/**
 * Factory for {@link RecordComparator} instances.
 */
public interface Factory {
  /**
   * @return a {@link RecordComparator} for the caller to use.
   */
  RecordComparator create();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
@Nullable
private final PrefixComparator prefixComparator;
@Nullable
private final RecordComparator recordComparator;
private final RecordComparator.Factory recordComparatorFactory;
private final TaskMemoryManager taskMemoryManager;
private final BlockManager blockManager;
private final SerializerManager serializerManager;
Expand Down Expand Up @@ -91,14 +91,14 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
RecordComparator.Factory recordComparatorFactory,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
long numElementsForSpillThreshold,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
serializerManager, taskContext, recordComparatorFactory, prefixComparator, initialSize,
numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
Expand All @@ -111,14 +111,14 @@ public static UnsafeExternalSorter create(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
RecordComparator.Factory recordComparatorFactory,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
long numElementsForSpillThreshold,
boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes,
taskContext, recordComparatorFactory, prefixComparator, initialSize, pageSizeBytes,
numElementsForSpillThreshold, null, canUseRadixSort);
}

Expand All @@ -127,7 +127,7 @@ private UnsafeExternalSorter(
BlockManager blockManager,
SerializerManager serializerManager,
TaskContext taskContext,
RecordComparator recordComparator,
RecordComparator.Factory recordComparatorFactory,
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
Expand All @@ -139,7 +139,7 @@ private UnsafeExternalSorter(
this.blockManager = blockManager;
this.serializerManager = serializerManager;
this.taskContext = taskContext;
this.recordComparator = recordComparator;
this.recordComparatorFactory = recordComparatorFactory;
this.prefixComparator = prefixComparator;
// Use getSizeAsKb (not bytes) to maintain backwards compatibility for units
// this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024
Expand All @@ -151,7 +151,7 @@ private UnsafeExternalSorter(

if (existingInMemorySorter == null) {
this.inMemSorter = new UnsafeInMemorySorter(
this, taskMemoryManager, recordComparator, prefixComparator, initialSize, canUseRadixSort);
this, taskMemoryManager, recordComparatorFactory, prefixComparator, initialSize, canUseRadixSort);
} else {
this.inMemSorter = existingInMemorySorter;
}
Expand Down Expand Up @@ -454,14 +454,14 @@ public void merge(UnsafeExternalSorter other) throws IOException {
* after consuming this iterator.
*/
public UnsafeSorterIterator getSortedIterator() throws IOException {
assert(recordComparator != null);
assert(recordComparatorFactory != null);
if (spillWriters.isEmpty()) {
assert(inMemSorter != null);
readingIterator = new SpillableIterator(inMemSorter.getSortedIterator());
return readingIterator;
} else {
final UnsafeSorterSpillMerger spillMerger =
new UnsafeSorterSpillMerger(recordComparator, prefixComparator, spillWriters.size());
new UnsafeSorterSpillMerger(recordComparatorFactory.create(), prefixComparator, spillWriters.size());
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
spillMerger.addSpillIfNotEmpty(spillWriter.getReader(serializerManager));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,21 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {

private final MemoryConsumer consumer;
private final TaskMemoryManager memoryManager;

private final PrefixComparator prefixComparator;

@Nullable
private final Comparator<RecordPointerAndKeyPrefix> sortComparator;
private final RecordComparator.Factory recordComparatorFactory;

/**
* If non-null, specifies the radix sort parameters and that radix sort will be used.
*/
@Nullable
private final PrefixComparators.RadixSortSupport radixSortSupport;

@Nullable
private Comparator<RecordPointerAndKeyPrefix> sortComparator;

/**
* Within this buffer, position {@code 2 * i} holds a pointer to the record at
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
Expand Down Expand Up @@ -118,26 +124,28 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
public UnsafeInMemorySorter(
final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final RecordComparator.Factory recordComparatorFactory,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to change UnsafeInMemorySorter? The TaskContext refers to UnsafeExternalSorter, so we only need the comparator factory in UnsafeExternalSorter.

final PrefixComparator prefixComparator,
int initialSize,
boolean canUseRadixSort) {
this(consumer, memoryManager, recordComparator, prefixComparator,
this(consumer, memoryManager, recordComparatorFactory, prefixComparator,
consumer.allocateArray(initialSize * 2), canUseRadixSort);
}

public UnsafeInMemorySorter(
final MemoryConsumer consumer,
final TaskMemoryManager memoryManager,
final RecordComparator recordComparator,
final RecordComparator.Factory recordComparatorFactory,
final PrefixComparator prefixComparator,
LongArray array,
boolean canUseRadixSort) {
this.consumer = consumer;
this.memoryManager = memoryManager;
this.initialSize = array.size();
if (recordComparator != null) {
this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
this.prefixComparator = prefixComparator;
this.recordComparatorFactory = recordComparatorFactory;
if (recordComparatorFactory != null) {
this.sortComparator = createNewSortComparator();
if (canUseRadixSort && prefixComparator instanceof PrefixComparators.RadixSortSupport) {
this.radixSortSupport = (PrefixComparators.RadixSortSupport)prefixComparator;
} else {
Expand All @@ -151,6 +159,10 @@ public UnsafeInMemorySorter(
this.usableCapacity = getUsableCapacity();
}

/**
 * Builds a {@code SortComparator} from a record comparator obtained via
 * {@code recordComparatorFactory.create()}, together with this sorter's prefix comparator
 * and memory manager.
 */
private SortComparator createNewSortComparator() {
  return new SortComparator(
    recordComparatorFactory.create(),
    prefixComparator,
    memoryManager);
}

private int getUsableCapacity() {
// Radix sort requires same amount of used memory as buffer, Tim sort requires
// half of the used memory as buffer.
Expand All @@ -175,6 +187,9 @@ public void reset() {
}
pos = 0;
nullBoundaryPos = 0;
if (sortComparator != null) {
sortComparator = createNewSortComparator();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public class UnsafeExternalSorterSuite {
final PrefixComparator prefixComparator = PrefixComparators.LONG;
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
// use a dummy comparator
final RecordComparator recordComparator = new RecordComparator() {
final RecordComparator.Factory recordComparatorFactory = new RecordComparator.Factory() {
@Override
public int compare(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset) {
return 0;
public RecordComparator create() {
return new RecordComparator() {
@Override
public int compare(Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) {
return 0;
}
};
}
};

Expand Down Expand Up @@ -154,7 +155,7 @@ private UnsafeExternalSorter newSorter() throws IOException {
blockManager,
serializerManager,
taskContext,
recordComparator,
recordComparatorFactory,
prefixComparator,
/* initialSize */ 1024,
pageSizeBytes,
Expand Down Expand Up @@ -415,7 +416,7 @@ public void testPeakMemoryUsed() throws Exception {
blockManager,
serializerManager,
taskContext,
recordComparator,
recordComparatorFactory,
prefixComparator,
1024,
pageSizeBytes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void testSortingEmptyInput() {
final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager);
final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer,
memoryManager,
mock(RecordComparator.class),
mock(RecordComparator.Factory.class),
mock(PrefixComparator.class),
100,
shouldUseRadixSort());
Expand Down Expand Up @@ -92,14 +92,15 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
}
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
// use a dummy comparator
final RecordComparator recordComparator = new RecordComparator() {
final RecordComparator.Factory recordComparator = new RecordComparator.Factory() {
@Override
public int compare(
Object leftBaseObject,
long leftBaseOffset,
Object rightBaseObject,
long rightBaseOffset) {
return 0;
public RecordComparator create() {
return new RecordComparator() {
@Override
public int compare(Object leftBaseObject, long leftBaseOffset, Object rightBaseObject, long rightBaseOffset) {
return 0;
}
};
}
};
// Compute key prefixes based on the records' partition ids
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public UnsafeExternalRowSorter(
sparkEnv.blockManager(),
sparkEnv.serializerManager(),
taskContext,
new RowComparator(ordering, schema.length()),
new RowComparatorFactory(ordering, schema.length()),
prefixComparator,
sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
DEFAULT_INITIAL_SORT_BUFFER_SIZE),
Expand Down Expand Up @@ -193,6 +193,21 @@ public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOExce
return sort();
}

/**
 * {@link RecordComparator.Factory} that remembers a row ordering and a field count and hands
 * them to a newly constructed {@link RowComparator} on every {@link #create()} call.
 */
private static final class RowComparatorFactory implements RecordComparator.Factory {
  private final Ordering<InternalRow> ordering;
  private final int numFields;

  private RowComparatorFactory(Ordering<InternalRow> rowOrdering, int fieldCount) {
    this.numFields = fieldCount;
    this.ordering = rowOrdering;
  }

  /** Constructs a new {@link RowComparator} from the stored ordering and field count. */
  @Override
  public RecordComparator create() {
    final RowComparator comparator = new RowComparator(ordering, numFields);
    return comparator;
  }
}

private static final class RowComparator extends RecordComparator {
private final Ordering<InternalRow> ordering;
private final int numFields;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public UnsafeKVExternalSorter(
prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
BaseOrdering ordering = GenerateOrdering.create(keySchema);
KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
KVComparatorFactory recordComparatorFactory = new KVComparatorFactory(ordering, keySchema.length());
boolean canUseRadixSort = keySchema.length() == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));

Expand All @@ -88,7 +88,7 @@ public UnsafeKVExternalSorter(
blockManager,
serializerManager,
taskContext,
recordComparator,
recordComparatorFactory,
prefixComparator,
SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
Expand All @@ -104,7 +104,7 @@ public UnsafeKVExternalSorter(
// as the underlying array for in-memory sorter (it's always large enough).
// Since we will not grow the array, it's fine to pass `null` as consumer.
final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter(
null, taskMemoryManager, recordComparator, prefixComparator, map.getArray(),
null, taskMemoryManager, recordComparatorFactory, prefixComparator, map.getArray(),
canUseRadixSort);

// We cannot use the destructive iterator here because we are reusing the existing memory
Expand Down Expand Up @@ -137,7 +137,7 @@ public UnsafeKVExternalSorter(
blockManager,
serializerManager,
taskContext,
new KVComparator(ordering, keySchema.length()),
recordComparatorFactory,
prefixComparator,
SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
Expand Down Expand Up @@ -223,6 +223,21 @@ public void cleanupResources() {
sorter.cleanupResources();
}

/**
 * {@link RecordComparator.Factory} that remembers a key ordering and the number of key fields
 * and hands them to a newly constructed {@link KVComparator} on every {@link #create()} call.
 */
private static final class KVComparatorFactory implements RecordComparator.Factory {
  private final BaseOrdering ordering;
  private final int numKeyFields;

  private KVComparatorFactory(BaseOrdering keyOrdering, int keyFieldCount) {
    this.numKeyFields = keyFieldCount;
    this.ordering = keyOrdering;
  }

  /** Constructs a new {@link KVComparator} from the stored ordering and key-field count. */
  @Override
  public RecordComparator create() {
    final KVComparator comparator = new KVComparator(ordering, numKeyFields);
    return comparator;
  }
}

private static final class KVComparator extends RecordComparator {
private final BaseOrdering ordering;
private final UnsafeRow row1;
Expand Down