Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.util.collection.unsafe.sort;

import com.google.common.primitives.Ints;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.LongArray;

Expand All @@ -40,28 +42,28 @@ public class RadixSort {
* of always copying the data back to position zero for efficiency.
*/
public static int sort(
LongArray array, int numRecords, int startByteIndex, int endByteIndex,
LongArray array, long numRecords, int startByteIndex, int endByteIndex,
boolean desc, boolean signed) {
assert startByteIndex >= 0 : "startByteIndex (" + startByteIndex + ") should >= 0";
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 2 <= array.size();
int inIndex = 0;
int outIndex = numRecords;
long inIndex = 0;
long outIndex = numRecords;
if (numRecords > 0) {
long[][] counts = getCounts(array, numRecords, startByteIndex, endByteIndex);
for (int i = startByteIndex; i <= endByteIndex; i++) {
if (counts[i] != null) {
sortAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
int tmp = inIndex;
long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
return inIndex;
return Ints.checkedCast(inIndex);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than pulling in a library method, how about using `require`, which lets you provide an actual error message?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @rxin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might make sense to add a very simple `if` check at the beginning of this function. I will do that when I merge this.

}

/**
Expand All @@ -78,14 +80,14 @@ public static int sort(
* @param signed whether this is a signed (two's complement) sort (only applies to last byte).
*/
private static void sortAtByte(
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 8, desc, signed);
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 8, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8;
long maxOffset = baseOffset + numRecords * 8;
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 8L;
for (long offset = baseOffset; offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
int bucket = (int)((value >>> (byteIdx * 8)) & 0xff);
Expand All @@ -106,13 +108,13 @@ private static void sortAtByte(
* significant byte. If the byte does not need sorting the array will be null.
*/
private static long[][] getCounts(
LongArray array, int numRecords, int startByteIndex, int endByteIndex) {
LongArray array, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
// Optimization: do a fast pre-pass to determine which byte indices we can skip for sorting.
// If all the byte values at a particular index are the same we don't need to count it.
long bitwiseMax = 0;
long bitwiseMin = -1L;
long maxOffset = array.getBaseOffset() + numRecords * 8;
long maxOffset = array.getBaseOffset() + numRecords * 8L;
Object baseObject = array.getBaseObject();
for (long offset = array.getBaseOffset(); offset < maxOffset; offset += 8) {
long value = Platform.getLong(baseObject, offset);
Expand Down Expand Up @@ -146,18 +148,18 @@ private static long[][] getCounts(
* @return the input counts array.
*/
private static long[] transformCountsToOffsets(
long[] counts, int numRecords, long outputOffset, int bytesPerRecord,
long[] counts, long numRecords, long outputOffset, long bytesPerRecord,
boolean desc, boolean signed) {
assert counts.length == 256;
int start = signed ? 128 : 0; // output the negative records first (values 129-255).
if (desc) {
int pos = numRecords;
long pos = numRecords;
for (int i = start; i < start + 256; i++) {
pos -= counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
}
} else {
int pos = 0;
long pos = 0;
for (int i = start; i < start + 256; i++) {
long tmp = counts[i & 0xff];
counts[i & 0xff] = outputOffset + pos * bytesPerRecord;
Expand All @@ -176,8 +178,8 @@ private static long[] transformCountsToOffsets(
*/
public static int sortKeyPrefixArray(
LongArray array,
int startIndex,
int numRecords,
long startIndex,
long numRecords,
int startByteIndex,
int endByteIndex,
boolean desc,
Expand All @@ -186,8 +188,8 @@ public static int sortKeyPrefixArray(
assert endByteIndex <= 7 : "endByteIndex (" + endByteIndex + ") should <= 7";
assert endByteIndex > startByteIndex;
assert numRecords * 4 <= array.size();
int inIndex = startIndex;
int outIndex = startIndex + numRecords * 2;
long inIndex = startIndex;
long outIndex = startIndex + numRecords * 2L;
if (numRecords > 0) {
long[][] counts = getKeyPrefixArrayCounts(
array, startIndex, numRecords, startByteIndex, endByteIndex);
Expand All @@ -196,21 +198,21 @@ public static int sortKeyPrefixArray(
sortKeyPrefixArrayAtByte(
array, numRecords, counts[i], i, inIndex, outIndex,
desc, signed && i == endByteIndex);
int tmp = inIndex;
long tmp = inIndex;
inIndex = outIndex;
outIndex = tmp;
}
}
}
return inIndex;
return Ints.checkedCast(inIndex);
}

/**
* Specialization of getCounts() for key-prefix arrays. We could probably combine this with
* getCounts with some added parameters but that seems to hurt in benchmarks.
*/
private static long[][] getKeyPrefixArrayCounts(
LongArray array, int startIndex, int numRecords, int startByteIndex, int endByteIndex) {
LongArray array, long startIndex, long numRecords, int startByteIndex, int endByteIndex) {
long[][] counts = new long[8][];
long bitwiseMax = 0;
long bitwiseMin = -1L;
Expand Down Expand Up @@ -238,11 +240,11 @@ private static long[][] getKeyPrefixArrayCounts(
* Specialization of sortAtByte() for key-prefix arrays.
*/
private static void sortKeyPrefixArrayAtByte(
LongArray array, int numRecords, long[] counts, int byteIdx, int inIndex, int outIndex,
LongArray array, long numRecords, long[] counts, int byteIdx, long inIndex, long outIndex,
boolean desc, boolean signed) {
assert counts.length == 256;
long[] offsets = transformCountsToOffsets(
counts, numRecords, array.getBaseOffset() + outIndex * 8, 16, desc, signed);
counts, numRecords, array.getBaseOffset() + outIndex * 8L, 16, desc, signed);
Object baseObject = array.getBaseObject();
long baseOffset = array.getBaseOffset() + inIndex * 8L;
long maxOffset = baseOffset + numRecords * 16L;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ public UnsafeSorterIterator getSortedIterator() {
if (sortComparator != null) {
if (this.radixSortSupport != null) {
offset = RadixSort.sortKeyPrefixArray(
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2, 0, 7,
array, nullBoundaryPos, (pos - nullBoundaryPos) / 2L, 0, 7,
radixSortSupport.sortDescending(), radixSortSupport.sortSigned());
} else {
MemoryBlock unused = new MemoryBlock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import java.util.{Arrays, Comparator}

import scala.util.Random

import com.google.common.primitives.Ints

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.unsafe.array.LongArray
Expand All @@ -30,7 +32,7 @@ import org.apache.spark.util.collection.Sorter
import org.apache.spark.util.random.XORShiftRandom

class RadixSortSuite extends SparkFunSuite with Logging {
private val N = 10000 // scale this down for more readable results
private val N = 10000L // scale this down for more readable results

/**
* Describes a type of sort to test, e.g. two's complement descending. Each sort type has
Expand Down Expand Up @@ -73,22 +75,22 @@ class RadixSortSuite extends SparkFunSuite with Logging {
},
2, 4, false, false, true))

private def generateTestData(size: Int, rand: => Long): (Array[JLong], LongArray) = {
val ref = Array.tabulate[Long](size) { i => rand }
val extended = ref ++ Array.fill[Long](size)(0)
private def generateTestData(size: Long, rand: => Long): (Array[JLong], LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size))(0)
(ref.map(i => new JLong(i)), new LongArray(MemoryBlock.fromLongArray(extended)))
}

private def generateKeyPrefixTestData(size: Int, rand: => Long): (LongArray, LongArray) = {
val ref = Array.tabulate[Long](size * 2) { i => rand }
val extended = ref ++ Array.fill[Long](size * 2)(0)
private def generateKeyPrefixTestData(size: Long, rand: => Long): (LongArray, LongArray) = {
val ref = Array.tabulate[Long](Ints.checkedCast(size * 2)) { i => rand }
val extended = ref ++ Array.fill[Long](Ints.checkedCast(size * 2))(0)
(new LongArray(MemoryBlock.fromLongArray(ref)),
new LongArray(MemoryBlock.fromLongArray(extended)))
}

private def collectToArray(array: LongArray, offset: Int, length: Int): Array[Long] = {
private def collectToArray(array: LongArray, offset: Int, length: Long): Array[Long] = {
var i = 0
val out = new Array[Long](length)
val out = new Array[Long](Ints.checkedCast(length))
while (i < length) {
out(i) = array.get(offset + i)
i += 1
Expand All @@ -107,15 +109,13 @@ class RadixSortSuite extends SparkFunSuite with Logging {
}
}

private def referenceKeyPrefixSort(buf: LongArray, lo: Int, hi: Int, refCmp: PrefixComparator) {
private def referenceKeyPrefixSort(buf: LongArray, lo: Long, hi: Long, refCmp: PrefixComparator) {
val sortBuffer = new LongArray(MemoryBlock.fromLongArray(new Array[Long](buf.size().toInt)))
new Sorter(new UnsafeSortDataFormat(sortBuffer)).sort(
buf, lo, hi, new Comparator[RecordPointerAndKeyPrefix] {
buf, Ints.checkedCast(lo), Ints.checkedCast(hi), new Comparator[RecordPointerAndKeyPrefix] {
override def compare(
r1: RecordPointerAndKeyPrefix,
r2: RecordPointerAndKeyPrefix): Int = {
refCmp.compare(r1.keyPrefix, r2.keyPrefix)
}
r2: RecordPointerAndKeyPrefix): Int = refCmp.compare(r1.keyPrefix, r2.keyPrefix)
})
}

Expand Down