(numFields);
+ for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) {
+ values.update(fieldNumber, get(fieldNumber));
+ }
+ return values;
+ }
+
+ @Override
+ public String toString() {
+ return mkString("[", ",", "]");
+ }
+
+ @Override
+ public String mkString() {
+ return toSeq().mkString();
+ }
+
+ @Override
+ public String mkString(String sep) {
+ return toSeq().mkString(sep);
+ }
+
+ @Override
+ public String mkString(String start, String sep, String end) {
+ return toSeq().mkString(start, sep, end);
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
new file mode 100644
index 0000000000000..4418c92fd6bc1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+
+/**
+ * Converts Rows into UnsafeRow format. This class is NOT thread-safe.
+ *
+ * @param fieldTypes the data types of the row's columns.
+ */
+class UnsafeRowConverter(fieldTypes: Array[DataType]) {
+
+ def this(schema: StructType) {
+ this(schema.fields.map(_.dataType))
+ }
+
+ /** Re-used pointer to the unsafe row being written */
+ private[this] val unsafeRow = new UnsafeRow()
+
+ /** Functions for encoding each column */
+ private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
+ fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
+ }
+
+ /** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
+ private[this] val fixedLengthSize: Int =
+ (8 * fieldTypes.length) + UnsafeRow.calculateBitSetWidthInBytes(fieldTypes.length)
+
+ /**
+ * Compute the amount of space, in bytes, required to encode the given row.
+ */
+ def getSizeRequirement(row: Row): Int = {
+ var fieldNumber = 0
+ var variableLengthFieldSize: Int = 0
+ while (fieldNumber < writers.length) {
+ if (!row.isNullAt(fieldNumber)) {
+ variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
+ }
+ fieldNumber += 1
+ }
+ fixedLengthSize + variableLengthFieldSize
+ }
+
+ /**
+ * Convert the given row into UnsafeRow format.
+ *
+ * @param row the row to convert
+ * @param baseObject the base object of the destination address
+ * @param baseOffset the base offset of the destination address
+ * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
+ */
+ def writeRow(row: Row, baseObject: Object, baseOffset: Long): Long = {
+ unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+ var fieldNumber = 0
+ var appendCursor: Int = fixedLengthSize
+ while (fieldNumber < writers.length) {
+ if (row.isNullAt(fieldNumber)) {
+ unsafeRow.setNullAt(fieldNumber)
+ } else {
+ appendCursor += writers(fieldNumber).write(
+ row(fieldNumber),
+ fieldNumber,
+ unsafeRow,
+ baseObject,
+ baseOffset,
+ appendCursor)
+ }
+ fieldNumber += 1
+ }
+ appendCursor
+ }
+
+}
+
+/**
+ * Function for writing a column into an UnsafeRow.
+ */
+private abstract class UnsafeColumnWriter[T] {
+ /**
+ * Write a value into an UnsafeRow.
+ *
+ * @param value the value to write
+ * @param columnNumber what column to write it to
+ * @param row a pointer to the unsafe row
+ * @param baseObject the base object of the target row's address
+ * @param baseOffset the base offset of the target row's address
+ * @param appendCursor the offset from the start of the unsafe row to the end of the row;
+ * used for calculating where variable-length data should be written
+ * @return the number of variable-length bytes written
+ */
+ def write(
+ value: T,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int
+
+ /**
+ * Return the number of bytes that are needed to write this variable-length value.
+ */
+ def getSize(value: T): Int
+}
+
+private object UnsafeColumnWriter {
+
+ def forType(dataType: DataType): UnsafeColumnWriter[_] = {
+ dataType match {
+ case IntegerType => IntUnsafeColumnWriter
+ case LongType => LongUnsafeColumnWriter
+ case FloatType => FloatUnsafeColumnWriter
+ case DoubleType => DoubleUnsafeColumnWriter
+ case StringType => StringUnsafeColumnWriter
+ case t =>
+ throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
+ }
+ }
+}
+
+// ------------------------------------------------------------------------------------------------
+
+private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
+private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
+private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
+private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
+private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+
+private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
+ def getSize(value: T): Int = 0
+}
+
+private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Int] {
+ override def write(
+ value: Int,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int = {
+ row.setInt(columnNumber, value)
+ 0
+ }
+}
+
+private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
+ override def write(
+ value: Long,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int = {
+ row.setLong(columnNumber, value)
+ 0
+ }
+}
+
+private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] {
+ override def write(
+ value: Float,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int = {
+ row.setFloat(columnNumber, value)
+ 0
+ }
+}
+
+private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] {
+ override def write(
+ value: Double,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int = {
+ row.setDouble(columnNumber, value)
+ 0
+ }
+}
+
+private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
+ def getSize(value: UTF8String): Int = {
+ 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
+ }
+
+ override def write(
+ value: UTF8String,
+ columnNumber: Int,
+ row: UnsafeRow,
+ baseObject: Object,
+ baseOffset: Long,
+ appendCursor: Int): Int = {
+ val numBytes = value.getBytes.length
+ PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
+ PlatformDependent.copyMemory(
+ value.getBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ baseOffset + appendCursor + 8,
+ numBytes
+ )
+ row.setLong(columnNumber, appendCursor)
+ 8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
new file mode 100644
index 0000000000000..f00f290ef911a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.unsafe.memory.{MemoryManager, MemoryAllocator}
+import org.scalatest.{BeforeAndAfterEach, FunSuite, Matchers}
+
+import org.apache.spark.sql.types._
+
+class UnsafeFixedWidthAggregationMapSuite extends FunSuite with Matchers with BeforeAndAfterEach {
+
+ import UnsafeFixedWidthAggregationMap._
+
+ private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
+ private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+ private def emptyAggregationBuffer: Row = new GenericRow(Array[Any](0))
+
+ private var memoryManager: MemoryManager = null
+
+ override def beforeEach(): Unit = {
+ memoryManager = new MemoryManager(true)
+ }
+
+ override def afterEach(): Unit = {
+ if (memoryManager != null) {
+ memoryManager.cleanUpAllPages()
+ memoryManager = null
+ }
+ }
+
+ test("supported schemas") {
+ assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
+ assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
+
+ assert(
+ !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ assert(
+ !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
+ }
+
+ test("empty map") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ memoryManager,
+ 1024,
+ false
+ )
+ assert(!map.iterator().hasNext)
+ map.free()
+ }
+
+ test("updating values for a single key") {
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyAggregationBuffer,
+ aggBufferSchema,
+ groupKeySchema,
+ memoryManager,
+ 1024,
+ false
+ )
+ val groupKey = new GenericRow(Array[Any](UTF8String("cats")))
+
+ // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts)
+ map.getAggregationBuffer(groupKey)
+ val iter = map.iterator()
+ val entry = iter.next()
+ assert(!iter.hasNext)
+ entry.key.getString(0) should be ("cats")
+ entry.value.getInt(0) should be (0)
+
+ // Modifications to rows retrieved from the map should update the values in the map
+ entry.value.setInt(0, 42)
+ map.getAggregationBuffer(groupKey).getInt(0) should be (42)
+
+ map.free()
+ }
+
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
new file mode 100644
index 0000000000000..211bc3333e386
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import java.util.Arrays
+
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.array.ByteArrayMethods
+
+class UnsafeRowConverterSuite extends FunSuite with Matchers {
+
+ test("basic conversion with only primitive types") {
+ val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.setLong(1, 1)
+ row.setInt(2, 2)
+
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ sizeRequired should be (8 + (3 * 8))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.getLong(0) should be (0)
+ unsafeRow.getLong(1) should be (1)
+ unsafeRow.getInt(2) should be (2)
+ }
+
+ test("basic conversion with primitive and string types") {
+ val fieldTypes: Array[DataType] = Array(LongType, StringType, StringType)
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.setString(1, "Hello")
+ row.setString(2, "World")
+
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ sizeRequired should be (8 + (8 * 3) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
+ ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length + 8))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ numBytesWritten should be (sizeRequired)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ unsafeRow.getLong(0) should be (0)
+ unsafeRow.getString(1) should be ("Hello")
+ unsafeRow.getString(2) should be ("World")
+ }
+
+  test("null handling") {
+    val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
+    val converter = new UnsafeRowConverter(fieldTypes)
+
+    val rowWithAllNullColumns: Row = {
+      val r = new SpecificMutableRow(fieldTypes)
+      for (i <- 0 to 3) {
+        r.setNullAt(i)
+      }
+      r
+    }
+
+    val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
+    val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(
+      rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    numBytesWritten should be (sizeRequired)
+
+    val createdFromNull = new UnsafeRow()
+    createdFromNull.pointTo(
+      createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    for (i <- 0 to 3) {
+      assert(createdFromNull.isNullAt(i))
+    }
+    createdFromNull.getInt(0) should be (0)
+    createdFromNull.getLong(1) should be (0)
+    assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
+    // Field 3 is a DoubleType, so read it back with getDouble, not getFloat.
+    assert(java.lang.Double.isNaN(createdFromNull.getDouble(3)))
+
+    // If we have an UnsafeRow with columns that are initially non-null and we null out those
+    // columns, then the serialized row representation should be identical to what we would get by
+    // creating an entirely null row via the converter
+    val rowWithNoNullColumns: Row = {
+      val r = new SpecificMutableRow(fieldTypes)
+      r.setInt(0, 100)
+      r.setLong(1, 200)
+      r.setFloat(2, 300)
+      r.setDouble(3, 400)
+      r
+    }
+    val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    converter.writeRow(
+      rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    val setToNullAfterCreation = new UnsafeRow()
+    setToNullAfterCreation.pointTo(
+      setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
+    setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
+    setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
+    setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))
+
+    for (i <- 0 to 3) {
+      setToNullAfterCreation.setNullAt(i)
+    }
+    assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 4fc5de7e824fe..2fa602a6082dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -30,6 +30,7 @@ private[spark] object SQLConf {
val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes"
val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions"
val CODEGEN_ENABLED = "spark.sql.codegen"
+ val UNSAFE_ENABLED = "spark.sql.unsafe.enabled"
val DIALECT = "spark.sql.dialect"
val PARQUET_BINARY_AS_STRING = "spark.sql.parquet.binaryAsString"
@@ -149,6 +150,14 @@ private[sql] class SQLConf extends Serializable {
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
+ /**
+ * When set to true, Spark SQL will use managed memory for certain operations. This option only
+ * takes effect if codegen is enabled.
+ *
+ * Defaults to false as this feature is currently experimental.
+ */
+ private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, "false").toBoolean
+
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index bcd20c06c6dca..04a8538c763c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1011,6 +1011,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
def codegenEnabled: Boolean = self.conf.codegenEnabled
+ def unsafeEnabled: Boolean = self.conf.unsafeEnabled
+
def numPartitions: Int = self.conf.numShufflePartitions
def strategies: Seq[Strategy] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b1ef6556de1e9..6bb0a5d32cb52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.execution
+import org.apache.spark.SparkEnv
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.memory.MemoryAllocator
case class AggregateEvaluation(
schema: Seq[Attribute],
@@ -41,13 +43,15 @@ case class AggregateEvaluation(
* @param groupingExpressions expressions that are evaluated to determine grouping.
* @param aggregateExpressions expressions that are computed for each group.
* @param child the input data source.
+ * @param unsafeEnabled whether to allow Unsafe-based aggregation buffers to be used.
*/
@DeveloperApi
case class GeneratedAggregate(
partial: Boolean,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
- child: SparkPlan)
+ child: SparkPlan,
+ unsafeEnabled: Boolean)
extends UnaryNode {
override def requiredChildDistribution: Seq[Distribution] =
@@ -225,6 +229,21 @@ case class GeneratedAggregate(
case e: Expression if groupMap.contains(e) => groupMap(e)
})
+ val aggregationBufferSchema: StructType = StructType.fromAttributes(computationSchema)
+
+ val groupKeySchema: StructType = {
+ val fields = groupingExpressions.zipWithIndex.map { case (expr, idx) =>
+ // This is a dummy field name
+ StructField(idx.toString, expr.dataType, expr.nullable)
+ }
+ StructType(fields)
+ }
+
+ val schemaSupportsUnsafe: Boolean = {
+ UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
+ UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
+ }
+
child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
@@ -265,7 +284,49 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
+ } else if (unsafeEnabled && schemaSupportsUnsafe) {
+ log.info("Using Unsafe-based aggregator")
+ val aggregationMap = new UnsafeFixedWidthAggregationMap(
+ newAggregationBuffer(EmptyRow),
+ aggregationBufferSchema,
+ groupKeySchema,
+ SparkEnv.get.unsafeMemoryManager,
+ 1024 * 16,
+ false
+ )
+
+ while (iter.hasNext) {
+ val currentRow: Row = iter.next()
+ val groupKey: Row = groupProjection(currentRow)
+ val aggregationBuffer = aggregationMap.getAggregationBuffer(groupKey)
+ updateProjection.target(aggregationBuffer)(joinedRow(aggregationBuffer, currentRow))
+ }
+
+ new Iterator[Row] {
+ private[this] val mapIterator = aggregationMap.iterator()
+ private[this] val resultProjection = resultProjectionBuilder()
+
+ def hasNext: Boolean = mapIterator.hasNext
+
+ def next(): Row = {
+ val entry = mapIterator.next()
+ val result = resultProjection(joinedRow(entry.key, entry.value))
+ if (hasNext) {
+ result
+ } else {
+ // This is the last element in the iterator, so let's free the buffer. Before we do,
+ // though, we need to make a defensive copy of the result so that we don't return an
+ // object that might contain dangling pointers to the freed memory
+ val resultCopy = result.copy()
+ aggregationMap.free()
+ resultCopy
+ }
+ }
+ }
} else {
+ if (unsafeEnabled) {
+ log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
+ }
val buffers = new java.util.HashMap[Row, MutableRow]()
var currentRow: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 030ef118f75d4..4c0369f0dbde4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -140,7 +140,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
partial = true,
groupingExpressions,
partialComputation,
- planLater(child))) :: Nil
+ planLater(child),
+ unsafeEnabled),
+ unsafeEnabled) :: Nil
// Cases where some aggregate can not be codegened
case PartialAggregation(
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
new file mode 100644
index 0000000000000..c40efef2eb109
--- /dev/null
+++ b/unsafe/pom.xml
@@ -0,0 +1,58 @@
+
+
+
+
+ 4.0.0
+
+ org.apache.spark
+ spark-parent_2.10
+ 1.4.0-SNAPSHOT
+ ../pom.xml
+
+
+ org.apache.spark
+ spark-unsafe_2.10
+ jar
+ Spark Project Unsafe
+ http://spark.apache.org/
+
+ unsafe
+
+
+
+
+ junit
+ junit
+ test
+
+
+ com.novocode
+ junit-interface
+ test
+
+
+ com.google.code.findbugs
+ jsr305
+
+
+
+ target/scala-${scala.binary.version}/classes
+ target/scala-${scala.binary.version}/test-classes
+
+
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
new file mode 100644
index 0000000000000..91b2f9aa43921
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/PlatformDependent.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe;
+
+import java.lang.reflect.Field;
+
+import sun.misc.Unsafe;
+
+public final class PlatformDependent {
+
+ public static final Unsafe UNSAFE;
+
+ public static final int BYTE_ARRAY_OFFSET;
+
+ public static final int INT_ARRAY_OFFSET;
+
+ public static final int LONG_ARRAY_OFFSET;
+
+ public static final int DOUBLE_ARRAY_OFFSET;
+
+ /**
+ * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to
+ * allow safepoint polling during a large copy.
+ */
+ private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L;
+
+ static {
+ sun.misc.Unsafe unsafe;
+ try {
+ Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafeField.setAccessible(true);
+ unsafe = (sun.misc.Unsafe) unsafeField.get(null);
+ } catch (Throwable cause) {
+ unsafe = null;
+ }
+ UNSAFE = unsafe;
+
+ if (UNSAFE != null) {
+ BYTE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(byte[].class);
+ INT_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(int[].class);
+ LONG_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(long[].class);
+ DOUBLE_ARRAY_OFFSET = UNSAFE.arrayBaseOffset(double[].class);
+ } else {
+ BYTE_ARRAY_OFFSET = 0;
+ INT_ARRAY_OFFSET = 0;
+ LONG_ARRAY_OFFSET = 0;
+ DOUBLE_ARRAY_OFFSET = 0;
+ }
+ }
+
+  public static void copyMemory(
+      Object src,
+      long srcOffset,
+      Object dst,
+      long dstOffset,
+      long length) {
+    while (length > 0) {
+      long size = Math.min(length, UNSAFE_COPY_THRESHOLD);
+      UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size);
+      length -= size;
+      srcOffset += size;
+      dstOffset += size;
+    }
+  }
+
+ /**
+ * Raises an exception bypassing compiler checks for checked exceptions.
+ */
+ public static void throwException(Throwable t) {
+ UNSAFE.throwException(t);
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
new file mode 100644
index 0000000000000..963b8398614c3
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class ByteArrayMethods {
+
+ private ByteArrayMethods() {
+ // Private constructor, since this class only contains static methods.
+ }
+
+ public static int roundNumberOfBytesToNearestWord(int numBytes) {
+ int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
+ if (remainder == 0) {
+ return numBytes;
+ } else {
+ return numBytes + (8 - remainder);
+ }
+ }
+
+ /**
+ * Optimized equality check for equal-length byte arrays.
+ * @return true if the arrays are equal, false otherwise
+ */
+ public static boolean arrayEquals(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset,
+ long arrayLengthInBytes) {
+ // TODO: this can be optimized by comparing words and falling back to individual byte
+ // comparison only at the end of the array (Guava's UnsignedBytes has an implementation of this)
+ for (int i = 0; i < arrayLengthInBytes; i++) {
+ final byte left =
+ PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i);
+ final byte right =
+ PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i);
+ if (left != right) return false;
+ }
+ return true;
+ }
+
+ /**
+ * Optimized byte array equality check for 8-byte-word-aligned byte arrays.
+ * @return true if the arrays are equal, false otherwise
+ */
+ public static boolean wordAlignedArrayEquals(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ Object rightBaseObject,
+ long rightBaseOffset,
+ long arrayLengthInBytes) {
+ for (int i = 0; i < arrayLengthInBytes; i += 8) {
+ final long left =
+ PlatformDependent.UNSAFE.getLong(leftBaseObject, leftBaseOffset + i);
+ final long right =
+ PlatformDependent.UNSAFE.getLong(rightBaseObject, rightBaseOffset + i);
+ if (left != right) return false;
+ }
+ return true;
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
new file mode 100644
index 0000000000000..18d1f0d2d7eb2
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * An array of long values. Compared with native JVM arrays, this:
+ *
+ * supports using both in-heap and off-heap memory
+ * has no bound checking, and thus can crash the JVM process when assert is turned off
+ *
+ */
+public final class LongArray {
+
+ // This is a long so that we perform long multiplications when computing offsets.
+ private static final long WIDTH = 8;
+
+ private final MemoryBlock memory;
+ private final Object baseObj;
+ private final long baseOffset;
+
+ private final long length;
+
+ public LongArray(MemoryBlock memory) {
+ assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")";
+ assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements";
+ this.memory = memory;
+ this.baseObj = memory.getBaseObject();
+ this.baseOffset = memory.getBaseOffset();
+ this.length = memory.size() / WIDTH;
+ }
+
+ public MemoryBlock memoryBlock() {
+ return memory;
+ }
+
+ /**
+ * Returns the number of elements this array can hold.
+ */
+ public long size() {
+ return length;
+ }
+
+ /**
+ * Sets the value at position {@code index}.
+ */
+ public void set(int index, long value) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < length : "index (" + index + ") should < length (" + length + ")";
+ PlatformDependent.UNSAFE.putLong(baseObj, baseOffset + index * WIDTH, value);
+ }
+
+ /**
+ * Returns the value at position {@code index}.
+ */
+ public long get(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < length : "index (" + index + ") should < length (" + length + ")";
+ return PlatformDependent.UNSAFE.getLong(baseObj, baseOffset + index * WIDTH);
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
new file mode 100644
index 0000000000000..f72e07fce92fd
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSet.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+/**
+ * A fixed size uncompressed bit set backed by a {@link LongArray}.
+ *
+ * Each bit occupies exactly one bit of storage.
+ */
+public final class BitSet {
+
+ /** A long array for the bits. */
+ private final LongArray words;
+
+ /** Length of the long array. */
+ private final int numWords;
+
+ private final Object baseObject;
+ private final long baseOffset;
+
+ /**
+ * Creates a new {@link BitSet} using the specified memory block. Size of the memory block must be
+ * multiple of 8 bytes (i.e. 64 bits).
+ */
+ public BitSet(MemoryBlock memory) {
+ words = new LongArray(memory);
+ assert (words.size() <= Integer.MAX_VALUE);
+ numWords = (int) words.size();
+ baseObject = words.memoryBlock().getBaseObject();
+ baseOffset = words.memoryBlock().getBaseOffset();
+ }
+
+ public MemoryBlock memoryBlock() {
+ return words.memoryBlock();
+ }
+
+ /**
+ * Returns the number of bits in this {@code BitSet}.
+ */
+ public long capacity() {
+ return numWords * 64;
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code true}.
+ */
+ public void set(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ BitSetMethods.set(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code false}.
+ */
+ public void unset(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ BitSetMethods.unset(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Returns {@code true} if the bit is set at the specified index.
+ */
+ public boolean isSet(int index) {
+ assert index < numWords * 64 : "index (" + index + ") should < length (" + numWords * 64 + ")";
+ return BitSetMethods.isSet(baseObject, baseOffset, index);
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then {@code -1} is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ *
+ * for (long i = bs.nextSetBit(0); i >= 0; i = bs.nextSetBit(i + 1)) {
+ * // operate on index i here
+ * }
+ *
+ *
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ public int nextSetBit(int fromIndex) {
+ return BitSetMethods.nextSetBit(baseObject, baseOffset, fromIndex, numWords);
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
new file mode 100644
index 0000000000000..f30626d8f4317
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/bitset/BitSetMethods.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Methods for working with fixed-size uncompressed bitsets.
+ *
+ * We assume that the bitset data is word-aligned (that is, a multiple of 8 bytes in length).
+ *
+ * Each bit occupies exactly one bit of storage.
+ */
+public final class BitSetMethods {
+
+ private static final long WORD_SIZE = 8;
+
+ private BitSetMethods() {
+ // Make the default constructor private, since this only holds static methods.
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code true}.
+ */
+ public static void set(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word | mask);
+ }
+
+ /**
+ * Sets the bit at the specified index to {@code false}.
+ */
+ public static void unset(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ PlatformDependent.UNSAFE.putLong(baseObject, wordOffset, word & ~mask);
+ }
+
+ /**
+ * Returns {@code true} if the bit is set at the specified index.
+ */
+ public static boolean isSet(Object baseObject, long baseOffset, int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ final long mask = 1L << (index & 0x3f); // mod 64 and shift
+ final long wordOffset = baseOffset + (index >> 6) * WORD_SIZE;
+ final long word = PlatformDependent.UNSAFE.getLong(baseObject, wordOffset);
+ return (word & mask) != 0;
+ }
+
+ /**
+ * Returns {@code true} if any bit is set.
+ */
+ public static boolean anySet(Object baseObject, long baseOffset, long bitSetWidthInBytes) {
+ for (int i = 0; i <= bitSetWidthInBytes; i++) {
+ if (PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i) != 0) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Returns the index of the first bit that is set to true that occurs on or after the
+ * specified starting index. If no such bit exists then {@code -1} is returned.
+ *
+ * To iterate over the true bits in a BitSet, use the following loop:
+ *
+ *
+ * for (long i = bs.nextSetBit(0, sizeInWords); i >= 0; i = bs.nextSetBit(i + 1, sizeInWords)) {
+ * // operate on index i here
+ * }
+ *
+ *
+ *
+ * @param fromIndex the index to start checking from (inclusive)
+ * @param bitsetSizeInWords the size of the bitset, measured in 8-byte words
+ * @return the index of the next set bit, or -1 if there is no such bit
+ */
+ public static int nextSetBit(
+ Object baseObject,
+ long baseOffset,
+ int fromIndex,
+ int bitsetSizeInWords) {
+ int wi = fromIndex >> 6;
+ if (wi >= bitsetSizeInWords) {
+ return -1;
+ }
+
+ // Try to find the next set bit in the current word
+ final int subIndex = fromIndex & 0x3f;
+ long word =
+ PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE) >> subIndex;
+ if (word != 0) {
+ return (wi << 6) + subIndex + java.lang.Long.numberOfTrailingZeros(word);
+ }
+
+ // Find the next set bit in the rest of the words
+ wi += 1;
+ while (wi < bitsetSizeInWords) {
+ word = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + wi * WORD_SIZE);
+ if (word != 0) {
+ return (wi << 6) + java.lang.Long.numberOfTrailingZeros(word);
+ }
+ wi += 1;
+ }
+
+ return -1;
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
new file mode 100644
index 0000000000000..85cd02469adb7
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.hash;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
/**
 * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction.
 */
public final class Murmur3_x86_32 {
  // Murmur3 mixing constants (fixed by the algorithm; do not change or hashes will differ
  // from every other Murmur3 implementation).
  private static final int C1 = 0xcc9e2d51;
  private static final int C2 = 0x1b873593;

  // Seed for the initial hash state; different seeds yield different hash families.
  private final int seed;

  public Murmur3_x86_32(int seed) {
    this.seed = seed;
  }

  @Override
  public String toString() {
    return "Murmur3_32(seed=" + seed + ")";
  }

  /** Hashes a single 4-byte int. */
  public int hashInt(int input) {
    int k1 = mixK1(input);
    int h1 = mixH1(seed, k1);

    return fmix(h1, 4);
  }

  /**
   * Hashes a word-aligned memory region.
   *
   * Note that although the input must be a multiple of 8 bytes, the loop consumes the region
   * 4 bytes (one int) at a time — the alignment requirement simply guarantees the 4-byte reads
   * never run past the end of the region.
   */
  public int hashUnsafeWords(Object baseObject, long baseOffset, int lengthInBytes) {
    // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
    assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
    int h1 = seed;
    for (int offset = 0; offset < lengthInBytes; offset += 4) {
      int halfWord = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
      int k1 = mixK1(halfWord);
      h1 = mixH1(h1, k1);
    }
    return fmix(h1, lengthInBytes);
  }

  /** Hashes a single 8-byte long by mixing its low and high 32-bit halves in turn. */
  public int hashLong(long input) {
    int low = (int) input;
    int high = (int) (input >>> 32);

    int k1 = mixK1(low);
    int h1 = mixH1(seed, k1);

    k1 = mixK1(high);
    h1 = mixH1(h1, k1);

    return fmix(h1, 8);
  }

  // Pre-mixes one 32-bit block before it is folded into the hash state.
  private static int mixK1(int k1) {
    k1 *= C1;
    k1 = Integer.rotateLeft(k1, 15);
    k1 *= C2;
    return k1;
  }

  // Folds a pre-mixed block into the running hash state.
  private static int mixH1(int h1, int k1) {
    h1 ^= k1;
    h1 = Integer.rotateLeft(h1, 13);
    h1 = h1 * 5 + 0xe6546b64;
    return h1;
  }

  // Finalization mix - force all bits of a hash block to avalanche
  private static int fmix(int h1, int length) {
    h1 ^= length;
    h1 ^= h1 >>> 16;
    h1 *= 0x85ebca6b;
    h1 ^= h1 >>> 13;
    h1 *= 0xc2b2ae35;
    h1 ^= h1 >>> 16;
    return h1;
  }
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
new file mode 100644
index 0000000000000..a9a72cdb36b0b
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java
@@ -0,0 +1,552 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Override;
+import java.lang.UnsupportedOperationException;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.spark.unsafe.*;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.memory.*;
+
/**
 * An append-only hash map where keys and values are contiguous regions of bytes.
 *
 * This is backed by a power-of-2-sized hash table, using quadratic probing with triangular numbers,
 * which is guaranteed to exhaust the space.
 *
 * Note that even though we use long for indexing, the map can support up to 2^31 keys because
 * we use 32 bit MurmurHash. In either case, if the key cardinality is so high, you should probably
 * be using sorting instead of hashing for better cache locality.
 *
 * This class is not thread safe.
 */
public final class BytesToBytesMap {

  /** Hasher used for all key hashcodes (fixed seed 0). */
  private static final Murmur3_x86_32 HASHER = new Murmur3_x86_32(0);

  /** Policy that picks the next table capacity when the growth threshold is exceeded. */
  private static final HashMapGrowthStrategy growthStrategy = HashMapGrowthStrategy.DOUBLING;

  private final MemoryManager memoryManager;

  /**
   * A linked list for tracking all allocated data pages so that we can free all of our memory.
   */
  // NOTE(review): raw type — every element added below is a MemoryBlock; List<MemoryBlock>
  // would avoid the unchecked cast warnings in free().
  private final List dataPages = new LinkedList();

  /**
   * The data page that will be used to store keys and values for new hashtable entries. When this
   * page becomes full, a new page will be allocated and this pointer will change to point to that
   * new page.
   */
  private MemoryBlock currentDataPage = null;

  /**
   * Offset into `currentDataPage` that points to the location where new data can be inserted into
   * the page.
   */
  private long pageCursor = 0;

  /**
   * The size of the data pages that hold key and value data. Map entries cannot span multiple
   * pages, so this limits the maximum entry size.
   */
  private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes

  // This choice of page table size and page size means that we can address up to 500 gigabytes
  // of memory.

  /**
   * A single array to store the key and value.
   *
   * Position {@code 2 * i} in the array is used to track a pointer to the key at index {@code i},
   * while position {@code 2 * i + 1} in the array holds key's full 32-bit hashcode.
   */
  private LongArray longArray;
  // TODO: we're wasting 32 bits of space here; we can probably store fewer bits of the hashcode
  // and exploit word-alignment to use fewer bits to hold the address.  This might let us store
  // only one long per map entry, increasing the chance that this array will fit in cache at the
  // expense of maybe performing more lookups if we have hash collisions.  Say that we stored only
  // 27 bits of the hashcode and 37 bits of the address.  37 bits is enough to address 1 terabyte
  // of RAM given word-alignment.  If we use 13 bits of this for our page table, that gives us a
  // maximum page size of 2^24 * 8 = ~134 megabytes per page. This change will require us to store
  // full base addresses in the page table for off-heap mode so that we can reconstruct the full
  // absolute memory addresses.

  /**
   * A {@link BitSet} used to track location of the map where the key is set.
   * Size of the bitset should be half of the size of the long array.
   */
  private BitSet bitset;

  /** Fraction of capacity at which the table is grown and rehashed. */
  private final double loadFactor;

  /**
   * Number of keys defined in the map.
   */
  private int size;

  /**
   * The map will be expanded once the number of keys exceeds this threshold.
   */
  private int growthThreshold;

  /**
   * Mask for truncating hashcodes so that they do not exceed the long array's size.
   */
  private int mask;

  /**
   * Return value of {@link BytesToBytesMap#lookup(Object, long, int)}.
   */
  private final Location loc;

  /** Whether the perf counters below are maintained. */
  private final boolean enablePerfMetrics;

  // Perf counters; only updated when enablePerfMetrics is true, and only readable then
  // (the getters throw IllegalStateException otherwise).
  private long timeSpentResizingMs = 0;

  private long numProbes = 0;

  private long numKeyLookups = 0;

  private long numHashCollisions = 0;

  /**
   * @param memoryManager allocator for the long array, bitset, and data pages
   * @param initialCapacity requested capacity; rounded up to a power of 2 with a minimum of 64
   * @param loadFactor growth trigger as a fraction of capacity
   * @param enablePerfMetrics whether to track resize time, probes, and hash collisions
   */
  public BytesToBytesMap(
      MemoryManager memoryManager,
      int initialCapacity,
      double loadFactor,
      boolean enablePerfMetrics) {
    this.memoryManager = memoryManager;
    this.loadFactor = loadFactor;
    this.loc = new Location();
    this.enablePerfMetrics = enablePerfMetrics;
    allocate(initialCapacity);
  }

  /** Convenience constructor: load factor 0.70, perf metrics disabled. */
  public BytesToBytesMap(MemoryManager memoryManager, int initialCapacity) {
    this(memoryManager, initialCapacity, 0.70, false);
  }

  /** Convenience constructor: load factor 0.70. */
  public BytesToBytesMap(
      MemoryManager memoryManager,
      int initialCapacity,
      boolean enablePerfMetrics) {
    this(memoryManager, initialCapacity, 0.70, enablePerfMetrics);
  }

  @Override
  protected void finalize() throws Throwable {
    try {
      // In case the programmer forgot to call `free()`, try to perform that cleanup now:
      free();
    } finally {
      super.finalize();
    }
  }

  /**
   * Returns the number of keys defined in the map.
   */
  public int size() { return size; }

  /**
   * Returns an iterator for iterating over the entries of this map.
   *
   * For efficiency, all calls to `next()` will return the same {@link Location} object.
   *
   * If any other lookups or operations are performed on this map while iterating over it, including
   * `lookup()`, the behavior of the returned iterator is undefined.
   */
  public Iterator iterator() {
    return new Iterator() {

      // Index of the next occupied slot, found by scanning the occupancy bitset; -1 when done.
      private int nextPos = bitset.nextSetBit(0);

      @Override
      public boolean hasNext() {
        return nextPos != -1;
      }

      @Override
      public Location next() {
        final int pos = nextPos;
        nextPos = bitset.nextSetBit(nextPos + 1);
        // Hashcode argument is unused for defined entries, so 0 is passed here.
        return loc.with(pos, 0, true);
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  /**
   * Looks up a key, and return a {@link Location} handle that can be used to test existence
   * and read/write values.
   *
   * This function always return the same {@link Location} instance to avoid object allocation.
   */
  public Location lookup(
      Object keyBaseObject,
      long keyBaseOffset,
      int keyRowLengthBytes) {
    if (enablePerfMetrics) {
      numKeyLookups++;
    }
    final int hashcode = HASHER.hashUnsafeWords(keyBaseObject, keyBaseOffset, keyRowLengthBytes);
    int pos = hashcode & mask;
    // Quadratic probing with triangular numbers: the step grows by 1 on each collision, so the
    // probe sequence visits every slot of the power-of-2 table. Terminates because an empty slot
    // always exists (the table grows before it fills up).
    int step = 1;
    while (true) {
      if (enablePerfMetrics) {
        numProbes++;
      }
      if (!bitset.isSet(pos)) {
        // This is a new key.
        return loc.with(pos, hashcode, false);
      } else {
        // Slot 2*pos+1 holds the stored key's full 32-bit hashcode; compare it first as a
        // cheap filter before the byte-by-byte key comparison.
        long stored = longArray.get(pos * 2 + 1);
        if ((int) (stored) == hashcode) {
          // Full hash code matches.  Let's compare the keys for equality.
          loc.with(pos, hashcode, true);
          if (loc.getKeyLength() == keyRowLengthBytes) {
            final MemoryLocation keyAddress = loc.getKeyAddress();
            final Object storedKeyBaseObject = keyAddress.getBaseObject();
            final long storedKeyBaseOffset = keyAddress.getBaseOffset();
            final boolean areEqual = ByteArrayMethods.wordAlignedArrayEquals(
              keyBaseObject,
              keyBaseOffset,
              storedKeyBaseObject,
              storedKeyBaseOffset,
              keyRowLengthBytes
            );
            if (areEqual) {
              return loc;
            } else {
              // Same hashcode but different key bytes: a true hash collision.
              if (enablePerfMetrics) {
                numHashCollisions++;
              }
            }
          }
        }
      }
      pos = (pos + step) & mask;
      step++;
    }
  }

  /**
   * Handle returned by {@link BytesToBytesMap#lookup(Object, long, int)} function.
   */
  public final class Location {
    /** An index into the hash map's Long array */
    private int pos;
    /** True if this location points to a position where a key is defined, false otherwise */
    private boolean isDefined;
    /**
     * The hashcode of the most recent key passed to
     * {@link BytesToBytesMap#lookup(Object, long, int)}. Caching this hashcode here allows us to
     * avoid re-hashing the key when storing a value for that key.
     */
    private int keyHashcode;
    // Reused address holders — same instances are returned from getKeyAddress()/getValueAddress()
    // on every call, to avoid per-lookup allocation.
    private final MemoryLocation keyMemoryLocation = new MemoryLocation();
    private final MemoryLocation valueMemoryLocation = new MemoryLocation();
    private int keyLength;
    private int valueLength;

    // Decodes a stored entry address and refreshes the cached key/value addresses and lengths.
    // Entry layout within a page, starting at the decoded offset:
    //   [8-byte key length][key bytes][8-byte value length][value bytes]
    private void updateAddressesAndSizes(long fullKeyAddress) {
      final Object page = memoryManager.getPage(fullKeyAddress);
      final long keyOffsetInPage = memoryManager.getOffsetInPage(fullKeyAddress);
      keyMemoryLocation.setObjAndOffset(page, keyOffsetInPage + 8);
      keyLength = (int) PlatformDependent.UNSAFE.getLong(page, keyOffsetInPage);
      valueMemoryLocation.setObjAndOffset(page, keyOffsetInPage + 8 + keyLength + 8);
      valueLength = (int) PlatformDependent.UNSAFE.getLong(page, keyOffsetInPage + 8 + keyLength);
    }

    // Repositions this (shared) handle onto slot `pos`; refreshes cached addresses when the slot
    // already holds an entry.
    Location with(int pos, int keyHashcode, boolean isDefined) {
      this.pos = pos;
      this.isDefined = isDefined;
      this.keyHashcode = keyHashcode;
      if (isDefined) {
        final long fullKeyAddress = longArray.get(pos * 2);
        updateAddressesAndSizes(fullKeyAddress);
      }
      return this;
    }

    /**
     * Returns true if the key is defined at this position, and false otherwise.
     */
    public boolean isDefined() {
      return isDefined;
    }

    /**
     * Returns the address of the key defined at this position.
     * This points to the first byte of the key data.
     * Unspecified behavior if the key is not defined.
     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
     */
    public MemoryLocation getKeyAddress() {
      assert (isDefined);
      return keyMemoryLocation;
    }

    /**
     * Returns the length of the key defined at this position.
     * Unspecified behavior if the key is not defined.
     */
    public int getKeyLength() {
      assert (isDefined);
      return keyLength;
    }

    /**
     * Returns the address of the value defined at this position.
     * This points to the first byte of the value data.
     * Unspecified behavior if the key is not defined.
     * For efficiency reasons, calls to this method always returns the same MemoryLocation object.
     */
    public MemoryLocation getValueAddress() {
      assert (isDefined);
      return valueMemoryLocation;
    }

    /**
     * Returns the length of the value defined at this position.
     * Unspecified behavior if the key is not defined.
     */
    public int getValueLength() {
      assert (isDefined);
      return valueLength;
    }

    /**
     * Store a new key and value. This method may only be called once for a given key; if you want
     * to update the value associated with a key, then you can directly manipulate the bytes stored
     * at the value address.
     *
     * It is only valid to call this method immediately after calling `lookup()` using the same key.
     *
     * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
     * will return information on the data stored by this `putNewKey` call.
     *
     * As an example usage, here's the proper way to store a new key:
     *
     * <pre>
     *   Location loc = map.lookup(keyBaseOffset, keyBaseObject, keyLengthInBytes);
     *   if (!loc.isDefined()) {
     *     loc.putNewKey(keyBaseOffset, keyBaseObject, keyLengthInBytes, ...)
     *   }
     * </pre>
     *
     * Unspecified behavior if the key is not defined.
     */
    public void putNewKey(
        Object keyBaseObject,
        long keyBaseOffset,
        int keyLengthBytes,
        Object valueBaseObject,
        long valueBaseOffset,
        int valueLengthBytes) {
      assert (!isDefined) : "Can only set value once for a key";
      isDefined = true;
      // Key and value must be word-aligned, matching hashUnsafeWords / wordAlignedArrayEquals.
      assert (keyLengthBytes % 8 == 0);
      assert (valueLengthBytes % 8 == 0);
      // Here, we'll copy the data into our data pages. Because we only store a relative offset from
      // the key address instead of storing the absolute address of the value, the key and value
      // must be stored in the same memory page.
      // (required space = key-length word + key + value-length word + value)
      final long requiredSize = 8 + 8 + keyLengthBytes + valueLengthBytes;
      assert(requiredSize <= PAGE_SIZE_BYTES);
      // NOTE(review): size and the occupancy bit are updated before the page allocation below;
      // if allocatePage throws, the map is left claiming a slot with no data.
      size++;
      bitset.set(pos);

      // If there's not enough space in the current page, allocate a new page:
      if (currentDataPage == null || PAGE_SIZE_BYTES - pageCursor < requiredSize) {
        MemoryBlock newPage = memoryManager.allocatePage(PAGE_SIZE_BYTES);
        dataPages.add(newPage);
        pageCursor = 0;
        currentDataPage = newPage;
      }

      // Compute all of our offsets up-front:
      final Object pageBaseObject = currentDataPage.getBaseObject();
      final long pageBaseOffset = currentDataPage.getBaseOffset();
      final long keySizeOffsetInPage = pageBaseOffset + pageCursor;
      pageCursor += 8;
      final long keyDataOffsetInPage = pageBaseOffset + pageCursor;
      pageCursor += keyLengthBytes;
      final long valueSizeOffsetInPage = pageBaseOffset + pageCursor;
      pageCursor += 8;
      final long valueDataOffsetInPage = pageBaseOffset + pageCursor;
      pageCursor += valueLengthBytes;

      // Copy the key
      PlatformDependent.UNSAFE.putLong(pageBaseObject, keySizeOffsetInPage, keyLengthBytes);
      PlatformDependent.UNSAFE.copyMemory(
        keyBaseObject, keyBaseOffset, pageBaseObject, keyDataOffsetInPage, keyLengthBytes);
      // Copy the value
      PlatformDependent.UNSAFE.putLong(pageBaseObject, valueSizeOffsetInPage, valueLengthBytes);
      PlatformDependent.UNSAFE.copyMemory(
        valueBaseObject, valueBaseOffset, pageBaseObject, valueDataOffsetInPage, valueLengthBytes);

      // Publish the entry: slot 2*pos holds the encoded address, 2*pos+1 the cached hashcode.
      final long storedKeyAddress = memoryManager.encodePageNumberAndOffset(
        currentDataPage, keySizeOffsetInPage);
      longArray.set(pos * 2, storedKeyAddress);
      longArray.set(pos * 2 + 1, keyHashcode);
      updateAddressesAndSizes(storedKeyAddress);
      // NOTE(review): isDefined was already set to true at the top of this method; this second
      // assignment is redundant but harmless.
      isDefined = true;
      if (size > growthThreshold) {
        growAndRehash();
      }
    }
  }

  /**
   * Allocate new data structures for this map. When calling this outside of the constructor,
   * make sure to keep references to the old data structures so that you can free them.
   *
   * @param capacity the new map capacity
   */
  private void allocate(int capacity) {
    capacity = Math.max((int) Math.min(Integer.MAX_VALUE, nextPowerOf2(capacity)), 64);
    // Two longs (address + hashcode) per slot = capacity * 8 * 2 bytes; one bit per slot
    // = capacity / 8 bytes for the bitset, which is zeroed so all slots start empty.
    longArray = new LongArray(memoryManager.allocator.allocate(capacity * 8 * 2));
    bitset = new BitSet(memoryManager.allocator.allocate(capacity / 8).zero());

    this.growthThreshold = (int) (capacity * loadFactor);
    this.mask = capacity - 1;
  }

  /**
   * Free all allocated memory associated with this map, including the storage for keys and values
   * as well as the hash map array itself.
   *
   * This method is idempotent.
   */
  public void free() {
    // Null out the references so that a second call (e.g. from finalize()) is a no-op.
    if (longArray != null) {
      memoryManager.allocator.free(longArray.memoryBlock());
      longArray = null;
    }
    if (bitset != null) {
      memoryManager.allocator.free(bitset.memoryBlock());
      bitset = null;
    }
    Iterator dataPagesIterator = dataPages.iterator();
    while (dataPagesIterator.hasNext()) {
      memoryManager.freePage(dataPagesIterator.next());
      dataPagesIterator.remove();
    }
    assert(dataPages.isEmpty());
  }

  /** Returns the total amount of memory, in bytes, consumed by this map's managed structures. */
  public long getTotalMemoryConsumption() {
    return (
      dataPages.size() * PAGE_SIZE_BYTES +
      bitset.memoryBlock().size() +
      longArray.memoryBlock().size());
  }

  /**
   * Returns the total amount of time spent resizing this map (in milliseconds).
   *
   * @throws IllegalStateException if perf metrics were not enabled at construction time
   */
  public long getTimeSpentResizingMs() {
    if (!enablePerfMetrics) {
      throw new IllegalStateException();
    }
    return timeSpentResizingMs;
  }


  /**
   * Returns the average number of probes per key lookup.
   *
   * @throws IllegalStateException if perf metrics were not enabled at construction time
   */
  public double getAverageProbesPerLookup() {
    if (!enablePerfMetrics) {
      throw new IllegalStateException();
    }
    return (1.0 * numProbes) / numKeyLookups;
  }

  /**
   * Returns the number of full-hashcode collisions observed during lookups.
   *
   * @throws IllegalStateException if perf metrics were not enabled at construction time
   */
  public long getNumHashCollisions() {
    if (!enablePerfMetrics) {
      throw new IllegalStateException();
    }
    return numHashCollisions;
  }

  /**
   * Grows the size of the hash table and re-hash everything.
   */
  private void growAndRehash() {
    long resizeStartTime = -1;
    if (enablePerfMetrics) {
      resizeStartTime = System.currentTimeMillis();
    }
    // Store references to the old data structures to be used when we re-hash
    final LongArray oldLongArray = longArray;
    final BitSet oldBitSet = bitset;
    final int oldCapacity = (int) oldBitSet.capacity();

    // Allocate the new data structures
    allocate(Math.min(Integer.MAX_VALUE, growthStrategy.nextCapacity(oldCapacity)));

    // Re-mask (we don't recompute the hashcode because we stored all 32 bits of it)
    for (int pos = oldBitSet.nextSetBit(0); pos >= 0; pos = oldBitSet.nextSetBit(pos + 1)) {
      final long keyPointer = oldLongArray.get(pos * 2);
      final int hashcode = (int) oldLongArray.get(pos * 2 + 1);
      int newPos = hashcode & mask;
      int step = 1;
      boolean keepGoing = true;

      // No need to check for equality here when we insert so this has one less if branch than
      // the similar code path in addWithoutResize.
      while (keepGoing) {
        if (!bitset.isSet(newPos)) {
          bitset.set(newPos);
          longArray.set(newPos * 2, keyPointer);
          longArray.set(newPos * 2 + 1, hashcode);
          keepGoing = false;
        } else {
          newPos = (newPos + step) & mask;
          step++;
        }
      }
    }

    // Deallocate the old data structures.
    memoryManager.allocator.free(oldLongArray.memoryBlock());
    memoryManager.allocator.free(oldBitSet.memoryBlock());
    if (enablePerfMetrics) {
      timeSpentResizingMs += System.currentTimeMillis() - resizeStartTime;
    }
  }

  /** Returns the next number greater or equal num that is power of 2. */
  private static long nextPowerOf2(long num) {
    final long highBit = Long.highestOneBit(num);
    return (highBit == num) ? num : highBit << 1;
  }
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
new file mode 100644
index 0000000000000..7c321baffe82d
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
/**
 * Interface that defines how we can grow the size of a hash map when it is over a threshold.
 */
public interface HashMapGrowthStrategy {

  /**
   * Returns the capacity to grow to, given the current capacity.
   */
  int nextCapacity(int currentCapacity);

  /**
   * Double the size of the hash map every time.
   */
  HashMapGrowthStrategy DOUBLING = new Doubling();

  class Doubling implements HashMapGrowthStrategy {
    @Override
    public int nextCapacity(int currentCapacity) {
      // Guard against int overflow: doubling any capacity above 2^30 would wrap to a negative
      // number. Compute in long arithmetic and clamp to Integer.MAX_VALUE instead.
      return (int) Math.min(Integer.MAX_VALUE, currentCapacity * 2L);
    }
  }

}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
new file mode 100644
index 0000000000000..bbe83d36cf36b
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+/**
+ * A simple {@link MemoryAllocator} that can allocate up to 16GB using a JVM long primitive array.
+ */
+public class HeapMemoryAllocator implements MemoryAllocator {
+
+  /**
+   * Allocates {@code size} bytes backed by an on-heap long[]. The request is rounded up to the
+   * next multiple of 8 so that a size that is not word-aligned still receives at least
+   * {@code size} usable bytes (the previous integer division silently truncated such requests).
+   */
+  @Override
+  public MemoryBlock allocate(long size) throws OutOfMemoryError {
+    long[] array = new long[(int) ((size + 7) / 8)];
+    return MemoryBlock.fromLongArray(array);
+  }
+
+  /** No-op: on-heap blocks are reclaimed by the garbage collector once unreferenced. */
+  @Override
+  public void free(MemoryBlock memory) {
+    // Do nothing
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
new file mode 100644
index 0000000000000..5192f68c862cf
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryAllocator.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+/**
+ * Allocates and frees {@link MemoryBlock}s. Two stock implementations are exposed as the
+ * constants {@link #UNSAFE} (off-heap, Unsafe-backed) and {@link #HEAP} (on-heap, long[]-backed).
+ */
+public interface MemoryAllocator {
+
+ /**
+ * Allocates a contiguous block of memory. Note that the allocated memory is not guaranteed
+ * to be zeroed out (call `zero()` on the result if this is necessary).
+ */
+ MemoryBlock allocate(long size) throws OutOfMemoryError;
+
+ // Releases a block previously returned by allocate(). For the heap allocator this is a
+ // no-op (the GC reclaims the array); for the unsafe allocator it frees native memory.
+ void free(MemoryBlock memory);
+
+ /** Shared off-heap allocator instance. */
+ MemoryAllocator UNSAFE = new UnsafeMemoryAllocator();
+
+ /** Shared on-heap allocator instance. */
+ MemoryAllocator HEAP = new HeapMemoryAllocator();
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
new file mode 100644
index 0000000000000..49963cc099b29
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryBlock.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import javax.annotation.Nullable;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A consecutive block of memory, starting at a {@link MemoryLocation} with a fixed size.
+ */
+public class MemoryBlock extends MemoryLocation {
+
+  /** Size of this block, in bytes. */
+  private final long length;
+
+  /**
+   * Optional page number; used when this MemoryBlock represents a page allocated by a
+   * MemoryManager. This is package-private and is modified by MemoryManager.
+   */
+  int pageNumber = -1;
+
+  public int getPageNumber() {
+    return pageNumber;
+  }
+
+  MemoryBlock(@Nullable Object obj, long offset, long length) {
+    super(obj, offset);
+    this.length = length;
+  }
+
+  /**
+   * Returns the size of the memory block, in bytes.
+   */
+  public long size() {
+    return length;
+  }
+
+  /**
+   * Clear the contents of this memory block. Returns `this` to facilitate chaining.
+   */
+  public MemoryBlock zero() {
+    PlatformDependent.UNSAFE.setMemory(obj, offset, length, (byte) 0);
+    return this;
+  }
+
+  /**
+   * Creates a memory block pointing to the memory used by the byte array.
+   */
+  public static MemoryBlock fromByteArray(final byte[] array) {
+    return new MemoryBlock(array, PlatformDependent.BYTE_ARRAY_OFFSET, array.length);
+  }
+
+  /**
+   * Creates a memory block pointing to the memory used by the long array.
+   */
+  public static MemoryBlock fromLongArray(final long[] array) {
+    // Multiply in long arithmetic: `array.length * 8` is evaluated as an int and would
+    // overflow (go negative) for arrays longer than 2^28 elements.
+    return new MemoryBlock(array, PlatformDependent.LONG_ARRAY_OFFSET, array.length * 8L);
+  }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
new file mode 100644
index 0000000000000..74ebc87dc978c
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryLocation.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import javax.annotation.Nullable;
+
+/**
+ * A memory location. Tracked either by a memory address (with off-heap allocation),
+ * or by an offset from a JVM object (in-heap allocation).
+ */
+public class MemoryLocation {
+
+ // Base object for in-heap addressing; null means off-heap, in which case `offset`
+ // holds an absolute native address rather than an offset within `obj`.
+ @Nullable
+ Object obj;
+
+ // Offset within `obj` (in-heap) or a raw native address (off-heap, obj == null).
+ long offset;
+
+ public MemoryLocation(@Nullable Object obj, long offset) {
+ this.obj = obj;
+ this.offset = offset;
+ }
+
+ // Creates an unset location (null object, zero offset), to be filled in later via
+ // setObjAndOffset().
+ public MemoryLocation() {
+ this(null, 0);
+ }
+
+ // Retargets this location in place — presumably so hot paths can reuse one instance
+ // instead of allocating; confirm against callers before relying on identity semantics.
+ public void setObjAndOffset(Object newObj, long newOffset) {
+ this.obj = newObj;
+ this.offset = newOffset;
+ }
+
+ public final Object getBaseObject() {
+ return obj;
+ }
+
+ public final long getBaseOffset() {
+ return offset;
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java
new file mode 100644
index 0000000000000..3b6c8b09f50e8
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/MemoryManager.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import java.util.BitSet;
+
+/**
+ * Manages the lifecycle of data pages exchanged between operators.
+ *
+ * Most of the complexity in this class deals with encoding of off-heap addresses into 64-bit longs.
+ * In off-heap mode, memory can be directly addressed with 64-bit longs. In on-heap mode, memory is
+ * addressed by the combination of a base Object reference and a 64-bit offset within that object.
+ * This is a problem when we want to store pointers to data structures inside of other structures,
+ * such as record pointers inside hashmaps or sorting buffers. Even if we decided to use 128 bits
+ * to address memory, we can't just store the address of the base object since it's not guaranteed
+ * to remain stable as the heap gets reorganized due to GC.
+ *
+ * Instead, we use the following approach to encode record pointers in 64-bit longs: for off-heap
+ * mode, just store the raw address, and for on-heap mode use the upper 13 bits of the address to
+ * store a "page number" and the lower 51 bits to store an offset within this page. These page
+ * numbers are used to index into a "page table" array inside of the MemoryManager in order to
+ * retrieve the base object.
+ */
+public final class MemoryManager {
+
+ /**
+ * The number of entries in the page table.
+ */
+ // NOTE(review): the cast binds before the shift, so this is ((int) 1L) << 13 == 8192 —
+ // the intended value, but `1 << 13` (or explicit parentheses) would be clearer.
+ private static final int PAGE_TABLE_SIZE = (int) 1L << 13;
+
+ /** Bit mask for the lower 51 bits of a long. */
+ private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
+
+ /** Bit mask for the upper 13 bits of a long */
+ private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+ /**
+ * Similar to an operating system's page table, this array maps page numbers into base object
+ * pointers, allowing us to translate between the hashtable's internal 64-bit address
+ * representation and the baseObject+offset representation which we use to support both in- and
+ * off-heap addresses. When using an off-heap allocator, every entry in this map will be `null`.
+ * When using an in-heap allocator, the entries in this map will point to pages' base objects.
+ * Entries are added to this map as new data pages are allocated.
+ */
+ private final MemoryBlock[] pageTable = new MemoryBlock[PAGE_TABLE_SIZE];
+
+ /**
+ * Bitmap for tracking free pages.
+ */
+ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE);
+
+ /**
+ * Allocator, exposed for enabling untracked allocations of temporary data structures.
+ */
+ public final MemoryAllocator allocator;
+
+ /**
+ * Tracks whether we're in-heap or off-heap. For off-heap, we short-circuit most of these methods
+ * without doing any masking or lookups. Since this branching should be well-predicted by the JIT,
+ * this extra layer of indirection / abstraction hopefully shouldn't be too expensive.
+ */
+ private final boolean inHeap;
+
+ /**
+ * Construct a new MemoryManager.
+ */
+ public MemoryManager(MemoryAllocator allocator) {
+ this.inHeap = allocator instanceof HeapMemoryAllocator;
+ this.allocator = allocator;
+ }
+
+ /**
+ * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is
+ * intended for allocating large blocks of memory that will be shared between operators.
+ */
+ public MemoryBlock allocatePage(long size) {
+ // Offsets within a page must fit in the lower 51 bits of the encoded address
+ // (see encodePageNumberAndOffset), so page sizes of 2^51 bytes or more are rejected.
+ if (size >= (1L << 51)) {
+ throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+ }
+
+ final int pageNumber;
+ // Reserve a page number under the lock; the allocation itself happens outside it so
+ // concurrent allocatePage() calls don't serialize on the (possibly slow) allocator.
+ synchronized (this) {
+ pageNumber = allocatedPages.nextClearBit(0);
+ if (pageNumber >= PAGE_TABLE_SIZE) {
+ throw new IllegalStateException(
+ "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages");
+ }
+ allocatedPages.set(pageNumber);
+ }
+ final MemoryBlock page = allocator.allocate(size);
+ page.pageNumber = pageNumber;
+ // NOTE(review): pageTable is written outside the synchronized block, so visibility of this
+ // entry to other threads relies on external synchronization — confirm the threading model.
+ pageTable[pageNumber] = page;
+ return page;
+ }
+
+ /**
+ * Free a block of memory allocated via {@link MemoryManager#allocatePage(long)}.
+ */
+ public void freePage(MemoryBlock page) {
+ assert (page.pageNumber != -1) :
+ "Called freePage() on memory that wasn't allocated with allocatePage()";
+
+ allocator.free(page);
+ synchronized (this) {
+ allocatedPages.clear(page.pageNumber);
+ }
+ // NOTE(review): page.pageNumber is not reset to -1 here, so a double free would pass the
+ // assertion above; the table entry is also nulled after the bitmap bit is released, leaving
+ // a brief window where the page number could be re-allocated while the stale entry remains.
+ pageTable[page.pageNumber] = null;
+ }
+
+ /**
+ * Given a memory page and offset within that page, encode this address into a 64-bit long.
+ * This address will remain valid as long as the corresponding page has not been freed.
+ */
+ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
+ if (inHeap) {
+ assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+ // Page number goes in the upper 13 bits; the offset is masked into the lower 51.
+ return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+ } else {
+ // Off-heap: the offset already is an absolute native address, so store it raw.
+ return offsetInPage;
+ }
+ }
+
+ /**
+ * Get the page associated with an address encoded by
+ * {@link MemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public Object getPage(long pagePlusOffsetAddress) {
+ if (inHeap) {
+ // Extract the upper 13 bits and look up the page's base object in the table.
+ final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+ assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+ final Object page = pageTable[pageNumber].getBaseObject();
+ assert (page != null);
+ return page;
+ } else {
+ // Off-heap addresses have no base object.
+ return null;
+ }
+ }
+
+ /**
+ * Get the offset associated with an address encoded by
+ * {@link MemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
+ */
+ public long getOffsetInPage(long pagePlusOffsetAddress) {
+ if (inHeap) {
+ return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+ } else {
+ // Off-heap: the encoded value is already the absolute address.
+ return pagePlusOffsetAddress;
+ }
+ }
+
+ /**
+ * Clean up all pages. This shouldn't be called in production code and is only exposed for tests.
+ */
+ public void cleanUpAllPages() {
+ // freePage() nulls the corresponding pageTable entry, which is safe while iterating
+ // because we iterate over the (unchanging) array references, not the bitmap.
+ for (MemoryBlock page : pageTable) {
+ if (page != null) {
+ freePage(page);
+ }
+ }
+ }
+}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
new file mode 100644
index 0000000000000..15898771fef25
--- /dev/null
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/UnsafeMemoryAllocator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.memory;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * A simple {@link MemoryAllocator} that uses {@code Unsafe} to allocate off-heap memory.
+ */
+public class UnsafeMemoryAllocator implements MemoryAllocator {
+
+ // Off-heap blocks carry a null base object; the raw native address lives in `offset`.
+ @Override
+ public MemoryBlock allocate(long size) throws OutOfMemoryError {
+ long address = PlatformDependent.UNSAFE.allocateMemory(size);
+ return new MemoryBlock(null, address, size);
+ }
+
+ @Override
+ public void free(MemoryBlock memory) {
+ assert (memory.obj == null) :
+ "baseObject not null; are you trying to use the off-heap allocator to free on-heap memory?";
+ // NOTE(review): the freed block still holds the stale address (offset is not cleared),
+ // so any use of the MemoryBlock after this call dereferences freed native memory.
+ PlatformDependent.UNSAFE.freeMemory(memory.offset);
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/TestLongArray.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/TestLongArray.java
new file mode 100644
index 0000000000000..e49e344041ad7
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/TestLongArray.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.array;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+public class TestLongArray {
+
+ // Builds a 2-element LongArray backed by a 16-byte on-heap buffer.
+ private static LongArray createTestData() {
+ byte[] bytes = new byte[16];
+ LongArray arr = new LongArray(MemoryBlock.fromByteArray(bytes));
+ arr.set(0, 1L);
+ arr.set(1, 2L);
+ // Deliberately overwrite index 1 so the assertions below also verify that set()
+ // replaces a previously written value.
+ arr.set(1, 3L);
+ return arr;
+ }
+
+ @Test
+ public void basicTest() {
+ LongArray arr = createTestData();
+ Assert.assertEquals(2, arr.size());
+ Assert.assertEquals(1L, arr.get(0));
+ Assert.assertEquals(3L, arr.get(1));
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/bitset/TestBitSet.java b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/TestBitSet.java
new file mode 100644
index 0000000000000..fa84e404fd4d4
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/bitset/TestBitSet.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.bitset;
+
+import junit.framework.Assert;
+import org.apache.spark.unsafe.bitset.BitSet;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.memory.MemoryBlock;
+
+// NOTE(review): assertions use the deprecated JUnit 3 junit.framework.Assert (imported above);
+// prefer org.junit.Assert for consistency with the other test classes in this patch.
+public class TestBitSet {
+
+ // Wraps a zeroed word-aligned long[] in the unsafe BitSet under test.
+ private static BitSet createBitSet(int capacity) {
+ assert capacity % 64 == 0;
+ return new BitSet(MemoryBlock.fromLongArray(new long[capacity / 64]).zero());
+ }
+
+ @Test
+ public void basicOps() {
+ BitSet bs = createBitSet(64);
+ Assert.assertEquals(64, bs.capacity());
+
+ // Make sure the bit set starts empty.
+ for (int i = 0; i < bs.capacity(); i++) {
+ Assert.assertFalse(bs.isSet(i));
+ }
+
+ // Set every bit and check it.
+ for (int i = 0; i < bs.capacity(); i++) {
+ bs.set(i);
+ Assert.assertTrue(bs.isSet(i));
+ }
+
+ // Unset every bit and check it.
+ for (int i = 0; i < bs.capacity(); i++) {
+ Assert.assertTrue(bs.isSet(i));
+ bs.unset(i);
+ Assert.assertFalse(bs.isSet(i));
+ }
+ }
+
+ @Test
+ public void traversal() {
+ BitSet bs = createBitSet(256);
+
+ // Empty set: probes inside the first word, mid-word, and at a word boundary all miss.
+ Assert.assertEquals(-1, bs.nextSetBit(0));
+ Assert.assertEquals(-1, bs.nextSetBit(10));
+ Assert.assertEquals(-1, bs.nextSetBit(64));
+
+ bs.set(10);
+ Assert.assertEquals(10, bs.nextSetBit(0));
+ Assert.assertEquals(10, bs.nextSetBit(1));
+ Assert.assertEquals(10, bs.nextSetBit(10));
+ Assert.assertEquals(-1, bs.nextSetBit(11));
+
+ bs.set(11);
+ Assert.assertEquals(10, bs.nextSetBit(10));
+ Assert.assertEquals(11, bs.nextSetBit(11));
+
+ // Skip a whole word and find it
+ bs.set(190);
+ Assert.assertEquals(190, bs.nextSetBit(12));
+
+ // Past the last set bit, and at/after capacity, the search must fail cleanly.
+ Assert.assertEquals(-1, bs.nextSetBit(191));
+ Assert.assertEquals(-1, bs.nextSetBit(256));
+ }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/hash/TestMurmur3_x86_32.java b/unsafe/src/test/java/org/apache/spark/unsafe/hash/TestMurmur3_x86_32.java
new file mode 100644
index 0000000000000..558cf4db87522
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/hash/TestMurmur3_x86_32.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.hash;
+
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
+
+import junit.framework.Assert;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.junit.Test;
+
+/**
+ * Test file based on Guava's Murmur3Hash32Test. Checks known hash values for fixed inputs,
+ * determinism on random inputs, and a loose collision-rate bound.
+ */
+public class TestMurmur3_x86_32 {
+
+  private static final Murmur3_x86_32 hasher = new Murmur3_x86_32(0);
+
+  @Test
+  public void testKnownIntegerInputs() {
+    Assert.assertEquals(593689054, hasher.hashInt(0));
+    Assert.assertEquals(-189366624, hasher.hashInt(-42));
+    Assert.assertEquals(-1134849565, hasher.hashInt(42));
+    Assert.assertEquals(-1718298732, hasher.hashInt(Integer.MIN_VALUE));
+    Assert.assertEquals(-1653689534, hasher.hashInt(Integer.MAX_VALUE));
+  }
+
+  @Test
+  public void testKnownLongInputs() {
+    Assert.assertEquals(1669671676, hasher.hashLong(0L));
+    Assert.assertEquals(-846261623, hasher.hashLong(-42L));
+    Assert.assertEquals(1871679806, hasher.hashLong(42L));
+    Assert.assertEquals(1366273829, hasher.hashLong(Long.MIN_VALUE));
+    Assert.assertEquals(-2106506049, hasher.hashLong(Long.MAX_VALUE));
+  }
+
+  @Test
+  public void randomizedStressTest() {
+    int size = 65536;
+    Random rand = new Random();
+
+    // A set used to track collision rate; typed to avoid raw-type unchecked warnings.
+    Set<Integer> hashcodes = new HashSet<Integer>();
+    for (int i = 0; i < size; i++) {
+      int vint = rand.nextInt();
+      long lint = rand.nextLong();
+      // Hashes must be deterministic for a given input.
+      Assert.assertEquals(hasher.hashInt(vint), hasher.hashInt(vint));
+      Assert.assertEquals(hasher.hashLong(lint), hasher.hashLong(lint));
+
+      hashcodes.add(hasher.hashLong(lint));
+    }
+
+    // A very loose bound.
+    Assert.assertTrue(hashcodes.size() > size * 0.95);
+  }
+
+  @Test
+  public void randomizedStressTestBytes() {
+    int size = 65536;
+    Random rand = new Random();
+
+    // A set used to track collision rate; typed to avoid raw-type unchecked warnings.
+    Set<Integer> hashcodes = new HashSet<Integer>();
+    for (int i = 0; i < size; i++) {
+      // Word-aligned lengths only: hashUnsafeWords operates on 8-byte words.
+      int byteArrSize = rand.nextInt(100) * 8;
+      byte[] bytes = new byte[byteArrSize];
+      rand.nextBytes(bytes);
+
+      Assert.assertEquals(
+        hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
+        hasher.hashUnsafeWords(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+
+      hashcodes.add(hasher.hashUnsafeWords(
+        bytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+    }
+
+    // A very loose bound.
+    Assert.assertTrue(hashcodes.size() > size * 0.95);
+  }
+
+  @Test
+  public void randomizedStressTestPaddedStrings() {
+    int size = 64000;
+    // A set used to track collision rate; typed to avoid raw-type unchecked warnings.
+    Set<Integer> hashcodes = new HashSet<Integer>();
+    for (int i = 0; i < size; i++) {
+      int byteArrSize = 8;
+      // NOTE(review): getBytes() uses the platform default charset; harmless for ASCII
+      // digits, but an explicit charset would make this fully platform-independent.
+      byte[] strBytes = ("" + i).getBytes();
+      byte[] paddedBytes = new byte[byteArrSize];
+      System.arraycopy(strBytes, 0, paddedBytes, 0, strBytes.length);
+
+      Assert.assertEquals(
+        hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize),
+        hasher.hashUnsafeWords(paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+
+      hashcodes.add(hasher.hashUnsafeWords(
+        paddedBytes, PlatformDependent.BYTE_ARRAY_OFFSET, byteArrSize));
+    }
+
+    // A very loose bound.
+    Assert.assertTrue(hashcodes.size() > size * 0.95);
+  }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractTestBytesToBytesMap.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractTestBytesToBytesMap.java
new file mode 100644
index 0000000000000..48abf605b7bdb
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/AbstractTestBytesToBytesMap.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import java.lang.Exception;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryLocation;
+import org.apache.spark.unsafe.memory.MemoryManager;
+import static org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET;
+
+/**
+ * Shared test suite for BytesToBytesMap; concrete subclasses supply the allocator
+ * (on-heap vs. off-heap) via {@link #getMemoryAllocator()}.
+ */
+public abstract class AbstractTestBytesToBytesMap {
+
+  /** Fixed seed so that test failures are reproducible. */
+  private final Random rand = new Random(42);
+
+  private MemoryManager memoryManager;
+
+  @Before
+  public void setup() {
+    memoryManager = new MemoryManager(getMemoryAllocator());
+  }
+
+  @After
+  public void tearDown() {
+    if (memoryManager != null) {
+      memoryManager.cleanUpAllPages();
+      memoryManager = null;
+    }
+  }
+
+  /** Subclasses pick the allocator that the map under test uses. */
+  protected abstract MemoryAllocator getMemoryAllocator();
+
+  /** Copies `size` bytes from an arbitrary baseObject+offset location into a fresh array. */
+  private static byte[] getByteArray(MemoryLocation loc, int size) {
+    final byte[] arr = new byte[size];
+    PlatformDependent.UNSAFE.copyMemory(
+      loc.getBaseObject(),
+      loc.getBaseOffset(),
+      arr,
+      BYTE_ARRAY_OFFSET,
+      size
+    );
+    return arr;
+  }
+
+  /** Returns `numWords * 8` random bytes; lengths are kept word-aligned for the map. */
+  private byte[] getRandomByteArray(int numWords) {
+    Assert.assertTrue(numWords > 0);
+    final int lengthInBytes = numWords * 8;
+    final byte[] bytes = new byte[lengthInBytes];
+    rand.nextBytes(bytes);
+    return bytes;
+  }
+
+  /**
+   * Fast equality checking for byte arrays, since these comparisons are a bottleneck
+   * in our stress tests.
+   */
+  private static boolean arrayEquals(
+      byte[] expected,
+      MemoryLocation actualAddr,
+      long actualLengthBytes) {
+    return (actualLengthBytes == expected.length) && ByteArrayMethods.wordAlignedArrayEquals(
+      expected,
+      BYTE_ARRAY_OFFSET,
+      actualAddr.getBaseObject(),
+      actualAddr.getBaseOffset(),
+      expected.length
+    );
+  }
+
+  @Test
+  public void emptyMap() {
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+    try {
+      Assert.assertEquals(0, map.size());
+      final int keyLengthInWords = 10;
+      final int keyLengthInBytes = keyLengthInWords * 8;
+      final byte[] key = getRandomByteArray(keyLengthInWords);
+      Assert.assertFalse(map.lookup(key, BYTE_ARRAY_OFFSET, keyLengthInBytes).isDefined());
+    } finally {
+      // Free the map like the other tests do; otherwise the off-heap subclass leaks memory.
+      map.free();
+    }
+  }
+
+  @Test
+  public void setAndRetrieveAKey() {
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, 64);
+    final int recordLengthWords = 10;
+    final int recordLengthBytes = recordLengthWords * 8;
+    final byte[] keyData = getRandomByteArray(recordLengthWords);
+    final byte[] valueData = getRandomByteArray(recordLengthWords);
+    try {
+      final BytesToBytesMap.Location loc =
+        map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes);
+      Assert.assertFalse(loc.isDefined());
+      loc.putNewKey(
+        keyData,
+        BYTE_ARRAY_OFFSET,
+        recordLengthBytes,
+        valueData,
+        BYTE_ARRAY_OFFSET,
+        recordLengthBytes
+      );
+      // After storing the key and value, the other location methods should return results that
+      // reflect the result of this store without us having to call lookup() again on the same key.
+      Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+      Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+      // After calling lookup() the location should still point to the correct data.
+      Assert.assertTrue(map.lookup(keyData, BYTE_ARRAY_OFFSET, recordLengthBytes).isDefined());
+      Assert.assertEquals(recordLengthBytes, loc.getKeyLength());
+      Assert.assertEquals(recordLengthBytes, loc.getValueLength());
+      Assert.assertArrayEquals(keyData, getByteArray(loc.getKeyAddress(), recordLengthBytes));
+      Assert.assertArrayEquals(valueData, getByteArray(loc.getValueAddress(), recordLengthBytes));
+
+      // Overwriting an existing key via putNewKey must be rejected (asserts inside the map).
+      try {
+        loc.putNewKey(
+          keyData,
+          BYTE_ARRAY_OFFSET,
+          recordLengthBytes,
+          valueData,
+          BYTE_ARRAY_OFFSET,
+          recordLengthBytes
+        );
+        Assert.fail("Should not be able to set a new value for a key");
+      } catch (AssertionError e) {
+        // Expected exception; do nothing.
+      }
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void iteratorTest() throws Exception {
+    final int size = 128;
+    // Start at half capacity so that insertion forces at least one growth/rehash.
+    BytesToBytesMap map = new BytesToBytesMap(memoryManager, size / 2);
+    try {
+      for (long i = 0; i < size; i++) {
+        final long[] value = new long[] { i };
+        final BytesToBytesMap.Location loc =
+          map.lookup(value, PlatformDependent.LONG_ARRAY_OFFSET, 8);
+        Assert.assertFalse(loc.isDefined());
+        loc.putNewKey(
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8,
+          value,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          8
+        );
+      }
+      final java.util.BitSet valuesSeen = new java.util.BitSet(size);
+      // Typed iterator (the raw Iterator previously used here yields Object from next()).
+      final Iterator<BytesToBytesMap.Location> iter = map.iterator();
+      while (iter.hasNext()) {
+        final BytesToBytesMap.Location loc = iter.next();
+        Assert.assertTrue(loc.isDefined());
+        final MemoryLocation keyAddress = loc.getKeyAddress();
+        final MemoryLocation valueAddress = loc.getValueAddress();
+        final long key = PlatformDependent.UNSAFE.getLong(
+          keyAddress.getBaseObject(), keyAddress.getBaseOffset());
+        final long value = PlatformDependent.UNSAFE.getLong(
+          valueAddress.getBaseObject(), valueAddress.getBaseOffset());
+        Assert.assertEquals(key, value);
+        valuesSeen.set((int) value);
+      }
+      // Every inserted key must be produced exactly once.
+      Assert.assertEquals(size, valuesSeen.cardinality());
+    } finally {
+      map.free();
+    }
+  }
+
+  @Test
+  public void randomizedStressTest() {
+    final int size = 65536;
+    // Java arrays' hashCodes() aren't based on the arrays' contents, so we need to wrap arrays
+    // into ByteBuffers in order to use them as keys here.
+    final Map<ByteBuffer, byte[]> expected = new HashMap<ByteBuffer, byte[]>();
+    final BytesToBytesMap map = new BytesToBytesMap(memoryManager, size);
+
+    try {
+      // Fill the map to 90% full so that we can trigger probing
+      for (int i = 0; i < size * 0.9; i++) {
+        final byte[] key = getRandomByteArray(rand.nextInt(256) + 1);
+        final byte[] value = getRandomByteArray(rand.nextInt(512) + 1);
+        if (!expected.containsKey(ByteBuffer.wrap(key))) {
+          expected.put(ByteBuffer.wrap(key), value);
+          final BytesToBytesMap.Location loc = map.lookup(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length
+          );
+          Assert.assertFalse(loc.isDefined());
+          loc.putNewKey(
+            key,
+            BYTE_ARRAY_OFFSET,
+            key.length,
+            value,
+            BYTE_ARRAY_OFFSET,
+            value.length
+          );
+          // After calling putNewKey, the following should be true, even before calling
+          // lookup():
+          Assert.assertTrue(loc.isDefined());
+          Assert.assertEquals(key.length, loc.getKeyLength());
+          Assert.assertEquals(value.length, loc.getValueLength());
+          Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), key.length));
+          Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), value.length));
+        }
+      }
+
+      // Every entry recorded in the reference map must be retrievable from the unsafe map.
+      for (Map.Entry<ByteBuffer, byte[]> entry : expected.entrySet()) {
+        final byte[] key = entry.getKey().array();
+        final byte[] value = entry.getValue();
+        final BytesToBytesMap.Location loc = map.lookup(key, BYTE_ARRAY_OFFSET, key.length);
+        Assert.assertTrue(loc.isDefined());
+        Assert.assertTrue(arrayEquals(key, loc.getKeyAddress(), loc.getKeyLength()));
+        Assert.assertTrue(arrayEquals(value, loc.getValueAddress(), loc.getValueLength()));
+      }
+    } finally {
+      map.free();
+    }
+  }
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOffHeap.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOffHeap.java
new file mode 100644
index 0000000000000..c52a5d59ea6d6
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOffHeap.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class TestBytesToBytesMapOffHeap extends AbstractTestBytesToBytesMap {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.UNSAFE;
+ }
+
+}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOnHeap.java b/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOnHeap.java
new file mode 100644
index 0000000000000..9fb412d9fae07
--- /dev/null
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/map/TestBytesToBytesMapOnHeap.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.unsafe.map;
+
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+
+public class TestBytesToBytesMapOnHeap extends AbstractTestBytesToBytesMap {
+
+ @Override
+ protected MemoryAllocator getMemoryAllocator() {
+ return MemoryAllocator.HEAP;
+ }
+
+}