Skip to content

Commit b26f1d3

Browse files
committed
Fix bug in murmur hash implementation.
1 parent 765243d commit b26f1d3

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ public void printPerfMetrics() {
242242
throw new IllegalStateException("Perf metrics not enabled");
243243
}
244244
System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
245+
System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
245246
System.out.println("Time spent resizing (ms): " + map.getTimeSpentResizingMs());
246247
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
247248
}

unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,12 @@ public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes
4848
// See https://code.google.com/p/guava-libraries/source/browse/guava/src/com/google/common/hash/Murmur3_32HashFunction.java#167
4949
// TODO(josh) veryify that this was implemented correctly
5050
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
51-
int k1 = 0;
5251
int h1 = seed;
5352
for (int offset = 0; offset < lengthInBytes; offset += 4) {
5453
int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
55-
56-
k1 ^= halfWord << offset;
54+
int k1 = mixK1(halfWord);
55+
h1 = mixH1(h1, k1);
5756
}
58-
h1 ^= mixK1(k1);
5957
return fmix(h1, lengthInBytes);
6058
}
6159

unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ public final class BytesToBytesMap {
149149

150150
private long numKeyLookups = 0;
151151

152+
private long numHashCollisions = 0;
153+
152154
public BytesToBytesMap(
153155
MemoryAllocator allocator,
154156
int initialCapacity,
@@ -257,6 +259,10 @@ public Location lookup(
257259
);
258260
if (areEqual) {
259261
return loc;
262+
} else {
263+
if (enablePerfMetrics) {
264+
numHashCollisions++;
265+
}
260266
}
261267
}
262268
}
@@ -532,6 +538,13 @@ public double getAverageProbesPerLookup() {
532538
return (1.0 * numProbes) / numKeyLookups;
533539
}
534540

541+
public long getNumHashCollisions() {
542+
if (!enablePerfMetrics) {
543+
throw new IllegalStateException();
544+
}
545+
return numHashCollisions;
546+
}
547+
535548
/**
536549
* Grows the size of the hash table and re-hash everything.
537550
*/

unsafe/src/test/java/org/apache/spark/unsafe/hash/TestMurmur3_x86_32.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,36 @@ public void randomizedStressTestBytes() {
8181
int byteArrSize = rand.nextInt(100) * 8;
8282
byte[] bytes = new byte[byteArrSize];
8383
rand.nextBytes(bytes);
84-
long memoryAddr = PlatformDependent.UNSAFE.allocateMemory(byteArrSize);
85-
PlatformDependent.copyMemory(
86-
bytes, PlatformDependent.BYTE_ARRAY_OFFSET, null, memoryAddr, byteArrSize);
8784

8885
Assert.assertEquals(
89-
hasher.hashUnsafeWords(null, memoryAddr, byteArrSize),
90-
hasher.hashUnsafeWords(null, memoryAddr, byteArrSize));
86+
hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
87+
hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
9188

92-
hashcodes.add(hasher.hashUnsafeWords(null, memoryAddr, byteArrSize));
93-
PlatformDependent.UNSAFE.freeMemory(memoryAddr);
89+
hashcodes.add(hasher.hashUnsafeWords(
90+
bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
91+
}
92+
93+
// A very loose bound.
94+
Assert.assertTrue(hashcodes.size() > size * 0.95);
95+
}
96+
97+
@Test
98+
public void randomizedStressTestPaddedStrings() {
99+
int size = 64000;
100+
// A set used to track collision rate.
101+
Set<Integer> hashcodes = new HashSet<Integer>();
102+
for (int i = 0; i < size; i++) {
103+
int byteArrSize = 8;
104+
byte[] strBytes = ("" + i).getBytes();
105+
byte[] paddedBytes = new byte[byteArrSize];
106+
System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
107+
108+
Assert.assertEquals(
109+
hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
110+
hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
111+
112+
hashcodes.add(hasher.hashUnsafeWords(
113+
paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
94114
}
95115

96116
// A very loose bound.

0 commit comments

Comments
 (0)