diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index e8c01d46a84c..5fef74a7970a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import java.io._ +import java.util.concurrent.atomic.LongAdder import com.esotericsoftware.kryo.{Kryo, KryoSerializable} import com.esotericsoftware.kryo.io.{Input, Output} @@ -398,8 +399,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private var numKeys = 0L // Tracking average number of probes per key lookup. - private var numKeyLookups = 0L - private var numProbes = 0L + private var numKeyLookups = new LongAdder + private var numProbes = new LongAdder // needed by serializer def this() = { @@ -485,8 +486,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - numKeyLookups += 1 - numProbes += 1 + numKeyLookups.increment() + numProbes.increment() if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -495,14 +496,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) - numKeyLookups += 1 - numProbes += 1 + numKeyLookups.increment() + numProbes.increment() while (array(pos + 1) != 0) { if (array(pos) == key) { return getRow(array(pos + 1), resultRow) } pos = nextSlot(pos) - numProbes += 1 + numProbes.increment() } } null @@ -530,8 +531,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - numKeyLookups += 1 - numProbes += 1 + numKeyLookups.increment() + numProbes.increment() if (key >= minKey && key <= maxKey) { val value = array((key - minKey).toInt) if (value > 0) { @@ -540,14 +541,14 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap } } else { var pos = firstSlot(key) - numKeyLookups += 1 - numProbes += 1 + numKeyLookups.increment() + numProbes.increment() while (array(pos + 1) != 0) { if (array(pos) == key) { return valueIter(array(pos + 1), resultRow) } pos = nextSlot(pos) - numProbes += 1 + numProbes.increment() } } null @@ -587,11 +588,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap private def updateIndex(key: Long, address: Long): Unit = { var pos = firstSlot(key) assert(numKeys < array.length / 2) - numKeyLookups += 1 - numProbes += 1 + numKeyLookups.increment() + numProbes.increment() while (array(pos) != key && array(pos + 1) != 0) { pos = nextSlot(pos) - numProbes += 1 + numProbes.increment() } if (array(pos + 1) == 0) { // this is the first value for this key, put the address in array. @@ -723,8 +724,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap writeLong(maxKey) writeLong(numKeys) writeLong(numValues) - writeLong(numKeyLookups) - writeLong(numProbes) + writeLong(numKeyLookups.longValue()) + writeLong(numProbes.longValue()) writeLong(array.length) writeLongArray(writeBuffer, array, array.length) @@ -766,8 +767,10 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap maxKey = readLong() numKeys = readLong() numValues = readLong() - numKeyLookups = readLong() - numProbes = readLong() + numKeyLookups = new LongAdder() + numKeyLookups.add(readLong()) + numProbes = new LongAdder() + numProbes.add(readLong()) val length = readLong().toInt mask = length - 2 @@ -789,7 +792,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap /** * Returns the average number of probes per key lookup. */ - def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups + def getAverageProbesPerLookup: Double = numProbes.doubleValue() / numKeyLookups.doubleValue() } private[joins] class LongHashedRelation(