Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.joins

import java.io._
import java.util.concurrent.atomic.LongAdder

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
Expand Down Expand Up @@ -398,8 +399,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private var numKeys = 0L

// Tracking average number of probes per key lookup.
private var numKeyLookups = 0L
private var numProbes = 0L
private var numKeyLookups = new LongAdder
private var numProbes = new LongAdder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised. I think LongToUnsafeRowMap is used in a single-threaded environment, so multi-thread contention should not be an issue here. Do you have any insights into how this fixes the perf issue?

Copy link
Contributor Author

@LuciferYang LuciferYang Dec 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially, I thought these two class-scope variables would affect the JIT's SIMD optimization (after Java 8). We tried adding -XX:-UseSuperWord to the executor Java opts to verify this view, but it had no effect with spark-2.1, although this patch can improve performance....


// needed by serializer
def this() = {
Expand Down Expand Up @@ -485,8 +486,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
numKeyLookups.increment()
numProbes.increment()
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
Expand All @@ -495,14 +496,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
numKeyLookups.increment()
numProbes.increment()
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return getRow(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
numProbes.increment()
}
}
null
Expand Down Expand Up @@ -530,8 +531,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
if (isDense) {
numKeyLookups += 1
numProbes += 1
numKeyLookups.increment()
numProbes.increment()
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
Expand All @@ -540,14 +541,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
numKeyLookups += 1
numProbes += 1
numKeyLookups.increment()
numProbes.increment()
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return valueIter(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
numProbes += 1
numProbes.increment()
}
}
null
Expand Down Expand Up @@ -587,11 +588,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
assert(numKeys < array.length / 2)
numKeyLookups += 1
numProbes += 1
numKeyLookups.increment()
numProbes.increment()
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
numProbes += 1
numProbes.increment()
}
if (array(pos + 1) == 0) {
// this is the first value for this key, put the address in array.
Expand Down Expand Up @@ -723,8 +724,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
writeLong(maxKey)
writeLong(numKeys)
writeLong(numValues)
writeLong(numKeyLookups)
writeLong(numProbes)
writeLong(numKeyLookups.longValue())
writeLong(numProbes.longValue())

writeLong(array.length)
writeLongArray(writeBuffer, array, array.length)
Expand Down Expand Up @@ -766,8 +767,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = readLong()
numKeys = readLong()
numValues = readLong()
numKeyLookups = readLong()
numProbes = readLong()
numKeyLookups = new LongAdder()
numKeyLookups.add(readLong())
numProbes = new LongAdder()
numProbes.add(readLong())

val length = readLong().toInt
mask = length - 2
Expand All @@ -789,7 +792,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
/**
* Returns the average number of probes per key lookup.
*/
def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
def getAverageProbesPerLookup: Double = numProbes.doubleValue() / numKeyLookups.doubleValue()
}

private[joins] class LongHashedRelation(
Expand Down