Skip to content

Commit abf7bfe

Browse files
committed
Add basic test case.
1 parent 81d52c5 commit abf7bfe

File tree

3 files changed

+157
-30
lines changed

3 files changed

+157
-30
lines changed

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSortDataFormat.java

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.unsafe.sort;
1919

20+
import static org.apache.spark.unsafe.sort.UnsafeSorter.KeyPointerAndPrefix;
2021
import org.apache.spark.util.collection.SortDataFormat;
2122

2223
/**
@@ -26,24 +27,11 @@
2627
* index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix.
2728
*/
2829
final class UnsafeSortDataFormat
29-
extends SortDataFormat<UnsafeSortDataFormat.KeyPointerAndPrefix, long[]> {
30+
extends SortDataFormat<KeyPointerAndPrefix, long[]> {
3031

3132
public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat();
3233

33-
private UnsafeSortDataFormat() { };
34-
35-
public static final class KeyPointerAndPrefix {
36-
/**
37-
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
38-
* description of how these addresses are encoded.
39-
*/
40-
long recordPointer;
41-
42-
/**
43-
* A key prefix, for use in comparisons.
44-
*/
45-
long keyPrefix;
46-
}
34+
private UnsafeSortDataFormat() { }
4735

4836
@Override
4937
public KeyPointerAndPrefix getKey(long[] data, int pos) {

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,24 @@
2020
import java.util.Comparator;
2121
import java.util.Iterator;
2222

23-
import org.apache.spark.unsafe.memory.MemoryLocation;
2423
import org.apache.spark.util.collection.Sorter;
2524
import org.apache.spark.unsafe.memory.TaskMemoryManager;
26-
import static org.apache.spark.unsafe.sort.UnsafeSortDataFormat.KeyPointerAndPrefix;
2725

2826
public final class UnsafeSorter {
2927

28+
/**
 * Mutable holder pairing a record pointer with that record's key prefix.
 *
 * The iterator returned by {@code getSortedIterator()} fills in a single shared instance on every
 * call to {@code next()}, so the field values are overwritten by the following call.
 */
public static final class KeyPointerAndPrefix {
29+
/**
30+
* A pointer to a record; see {@link org.apache.spark.unsafe.memory.TaskMemoryManager} for a
31+
* description of how these addresses are encoded.
32+
*/
33+
long recordPointer;
34+
35+
/**
36+
* A key prefix, for use in comparisons. In the sort buffer it is stored in the slot directly
* after the record pointer.
37+
*/
38+
long keyPrefix;
39+
}
40+
3041
public static abstract class RecordComparator {
3142
public abstract int compare(
3243
Object leftBaseObject,
@@ -105,25 +116,23 @@ public void insertRecord(long objectAddress) {
105116
sortBufferInsertPosition += 2;
106117
}
107118

108-
public Iterator<MemoryLocation> getSortedIterator() {
109-
final MemoryLocation memoryLocation = new MemoryLocation();
119+
public Iterator<KeyPointerAndPrefix> getSortedIterator() {
110120
sorter.sort(sortBuffer, 0, sortBufferInsertPosition, sortComparator);
111-
return new Iterator<MemoryLocation>() {
112-
int position = 0;
121+
return new Iterator<KeyPointerAndPrefix>() {
122+
private int position = 0;
123+
private final KeyPointerAndPrefix keyPointerAndPrefix = new KeyPointerAndPrefix();
113124

114125
@Override
115126
public boolean hasNext() {
116127
return position < sortBufferInsertPosition;
117128
}
118129

119130
@Override
120-
public MemoryLocation next() {
121-
final long address = sortBuffer[position];
131+
public KeyPointerAndPrefix next() {
132+
keyPointerAndPrefix.recordPointer = sortBuffer[position];
133+
keyPointerAndPrefix.keyPrefix = sortBuffer[position + 1];
122134
position += 2;
123-
final Object baseObject = memoryManager.getPage(address);
124-
final long baseOffset = memoryManager.getOffsetInPage(address);
125-
memoryLocation.setObjAndOffset(baseObject, baseOffset);
126-
return memoryLocation;
135+
return keyPointerAndPrefix;
127136
}
128137

129138
@Override
Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,137 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
118
package org.apache.spark.unsafe.sort;
219

3-
/**
4-
* Created by joshrosen on 4/29/15.
5-
*/
20+
import java.util.Arrays;
21+
import java.util.Iterator;
22+
23+
import org.junit.Assert;
24+
import org.junit.Test;
25+
26+
import org.apache.spark.HashPartitioner;
27+
import org.apache.spark.unsafe.PlatformDependent;
28+
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
29+
import org.apache.spark.unsafe.memory.MemoryAllocator;
30+
import org.apache.spark.unsafe.memory.MemoryBlock;
31+
import org.apache.spark.unsafe.memory.TaskMemoryManager;
32+
633
public class UnsafeSorterSuite {
34+
35+
private static String getStringFromDataPage(Object baseObject, long baseOffset) {
36+
final int strLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
37+
final byte[] strBytes = new byte[strLength];
38+
PlatformDependent.UNSAFE.copyMemory(
39+
baseObject,
40+
baseOffset + 8,
41+
strBytes,
42+
PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
43+
return new String(strBytes);
44+
}
45+
46+
/**
47+
* Tests the type of sorting that's used in the non-combiner path of sort-based shuffle.
48+
*/
49+
@Test
50+
public void testSortingOnlyByPartitionId() throws Exception {
51+
final String[] dataToSort = new String[] {
52+
"Boba",
53+
"Pearls",
54+
"Tapioca",
55+
"Taho",
56+
"Condensed Milk",
57+
"Jasmine",
58+
"Milk Tea",
59+
"Lychee",
60+
"Mango"
61+
};
62+
final TaskMemoryManager memoryManager =
63+
new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
64+
final MemoryBlock dataPage = memoryManager.allocatePage(2048);
65+
final Object baseObject = dataPage.getBaseObject();
66+
// Write the records into the data page:
67+
long position = dataPage.getBaseOffset();
68+
for (String str : dataToSort) {
69+
final byte[] strBytes = str.getBytes("utf-8");
70+
PlatformDependent.UNSAFE.putLong(baseObject, position, strBytes.length);
71+
position += 8;
72+
PlatformDependent.copyMemory(
73+
strBytes,
74+
PlatformDependent.BYTE_ARRAY_OFFSET,
75+
baseObject,
76+
position,
77+
strBytes.length);
78+
position += strBytes.length;
79+
}
80+
// Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
81+
// use a dummy comparator
82+
final UnsafeSorter.RecordComparator recordComparator = new UnsafeSorter.RecordComparator() {
83+
@Override
84+
public int compare(
85+
Object leftBaseObject,
86+
long leftBaseOffset,
87+
Object rightBaseObject,
88+
long rightBaseOffset) {
89+
return 0;
90+
}
91+
};
92+
// Compute key prefixes based on the records' partition ids
93+
final HashPartitioner hashPartitioner = new HashPartitioner(4);
94+
final UnsafeSorter.PrefixComputer prefixComputer = new UnsafeSorter.PrefixComputer() {
95+
@Override
96+
public long computePrefix(Object baseObject, long baseOffset) {
97+
final String str = getStringFromDataPage(baseObject, baseOffset);
98+
final int partitionId = hashPartitioner.getPartition(str);
99+
return (long) partitionId;
100+
}
101+
};
102+
// Use integer comparison for comparing prefixes (which are partition ids, in this case)
103+
final UnsafeSorter.PrefixComparator prefixComparator = new UnsafeSorter.PrefixComparator() {
104+
@Override
105+
public int compare(long prefix1, long prefix2) {
106+
return (int) prefix1 - (int) prefix2;
107+
}
108+
};
109+
final UnsafeSorter sorter =
110+
new UnsafeSorter(memoryManager, recordComparator, prefixComputer, prefixComparator);
111+
// Given a page of records, insert those records into the sorter one-by-one:
112+
position = dataPage.getBaseOffset();
113+
for (int i = 0; i < dataToSort.length; i++) {
114+
// position now points to the start of a record (which holds its length).
115+
final long recordLength = PlatformDependent.UNSAFE.getLong(baseObject, position);
116+
final long address = memoryManager.encodePageNumberAndOffset(dataPage, position);
117+
sorter.insertRecord(address);
118+
position += 8 + recordLength;
119+
}
120+
final Iterator<UnsafeSorter.KeyPointerAndPrefix> iter = sorter.getSortedIterator();
121+
int iterLength = 0;
122+
long prevPrefix = -1;
123+
Arrays.sort(dataToSort);
124+
while (iter.hasNext()) {
125+
final UnsafeSorter.KeyPointerAndPrefix pointerAndPrefix = iter.next();
126+
final Object recordBaseObject = memoryManager.getPage(pointerAndPrefix.recordPointer);
127+
final long recordBaseOffset = memoryManager.getOffsetInPage(pointerAndPrefix.recordPointer);
128+
final String str = getStringFromDataPage(recordBaseObject, recordBaseOffset);
129+
Assert.assertTrue("String should be valid", Arrays.binarySearch(dataToSort, str) != -1);
130+
Assert.assertTrue("Prefix " + pointerAndPrefix.keyPrefix + " should be >= previous prefix " +
131+
prevPrefix, pointerAndPrefix.keyPrefix >= prevPrefix);
132+
prevPrefix = pointerAndPrefix.keyPrefix;
133+
iterLength++;
134+
}
135+
Assert.assertEquals(dataToSort.length, iterLength);
136+
}
7137
}

0 commit comments

Comments
 (0)