diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index 40c2cc806e87..1f243406c77e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -20,8 +20,13 @@ import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.RecordComparator; +import java.nio.ByteOrder; + public final class RecordBinaryComparator extends RecordComparator { + private static final boolean LITTLE_ENDIAN = + ByteOrder.nativeOrder().equals(ByteOrder.LITTLE_ENDIAN); + @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { @@ -38,10 +43,10 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; - final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + final int v1 = Platform.getByte(leftObj, leftOff + i); + final int v2 = Platform.getByte(rightObj, rightOff + i); if (v1 != v2) { - return v1 > v2 ? 1 : -1; + return (v1 & 0xff) > (v2 & 0xff) ? 1 : -1; } i += 1; } @@ -49,10 +54,17 @@ public int compare( // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - final long v1 = Platform.getLong(leftObj, leftOff + i); - final long v2 = Platform.getLong(rightObj, rightOff + i); + long v1 = Platform.getLong(leftObj, leftOff + i); + long v2 = Platform.getLong(rightObj, rightOff + i); if (v1 != v2) { - return v1 > v2 ? 1 : -1; + if (LITTLE_ENDIAN) { + // if read as little-endian, we have to reverse bytes so that the long comparison result + // is equivalent to byte-by-byte comparison result. + // See discussion in https://github.com/apache/spark/pull/26548#issuecomment-554645859 + v1 = Long.reverseBytes(v1); + v2 = Long.reverseBytes(v2); + } + return Long.compareUnsigned(v1, v2); } i += 8; } @@ -60,10 +72,10 @@ public int compare( // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; - final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + final int v1 = Platform.getByte(leftObj, leftOff + i); + final int v2 = Platform.getByte(rightObj, rightOff + i); if (v1 != v2) { - return v1 > v2 ? 1 : -1; + return (v1 & 0xff) > (v2 & 0xff) ? 1 : -1; } i += 1; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index 92dabc79d2bf..68f984ae0c1e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -33,6 +33,7 @@ import org.apache.spark.util.collection.unsafe.sort.*; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -273,7 +274,7 @@ public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws insertRow(row1); insertRow(row2); - assert(compare(0, 1) < 0); + assert(compare(0, 1) > 0); } @Test @@ -321,4 +322,48 @@ public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception assert(compare(0, 1) < 0); } + + @Test + public void testCompareLongsAsLittleEndian() { + long arrayOffset = 12; + + long[] arr1 = new long[2]; + Platform.putLong(arr1, arrayOffset, 0x0100000000000000L); + long[] arr2 = new long[2]; + Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000001L); + // leftBaseOffset is not aligned while rightBaseOffset is aligned, + // it will start by comparing long + int result1 = binaryComparator.compare(arr1, arrayOffset, 8, arr2, arrayOffset + 4, 8); + + long[] arr3 = new long[2]; + Platform.putLong(arr3, arrayOffset, 0x0100000000000000L); + long[] arr4 = new long[2]; + Platform.putLong(arr4, arrayOffset, 0x0000000000000001L); + // both left and right offset is not aligned, it will start with byte-by-byte comparison + int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8); + + Assert.assertEquals(result1, result2); + } + + @Test + public void testCompareLongsAsUnsigned() { + long arrayOffset = 12; + + long[] arr1 = new long[2]; + Platform.putLong(arr1, arrayOffset + 4, 0xa000000000000000L); + long[] arr2 = new long[2]; + Platform.putLong(arr2, arrayOffset + 4, 0x0000000000000000L); + // both leftBaseOffset and rightBaseOffset are aligned, so it will start by comparing long + int result1 = binaryComparator.compare(arr1, arrayOffset + 4, 8, arr2, arrayOffset + 4, 8); + + long[] arr3 = new long[2]; + Platform.putLong(arr3, arrayOffset, 0xa000000000000000L); + long[] arr4 = new long[2]; + Platform.putLong(arr4, arrayOffset, 0x0000000000000000L); + // both leftBaseOffset and rightBaseOffset are not aligned, + // so it will start with byte-by-byte comparison + int result2 = binaryComparator.compare(arr3, arrayOffset, 8, arr4, arrayOffset, 8); + + Assert.assertEquals(result1, result2); + } }