diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java index afea4676893e..f8bbfbbac1b7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.codegen; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; import org.apache.spark.sql.types.Decimal; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; @@ -259,4 +261,116 @@ public void write(int ordinal, CalendarInterval input) { // move the cursor forward. holder.cursor += 16; } + + public void writePrimitiveBooleanArray(ArrayData arrayData) { + boolean[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).booleanArray()) != null) { + int length = input.length; + Platform.copyMemory(input, Platform.BOOLEAN_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putBoolean(holder.buffer, startingOffset + headerInBytes + i, + arrayData.getBoolean(i)); + } + } + } + + public void writePrimitiveByteArray(ArrayData arrayData) { + byte[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).byteArray()) != null) { + int length = input.length; + Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, + arrayData.getByte(i)); + } + } + } + + public void writePrimitiveShortArray(ArrayData arrayData) { + short[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).shortArray()) != null) { + int length = input.length * 2; + Platform.copyMemory(input, Platform.SHORT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putShort(holder.buffer, startingOffset + headerInBytes + i * 2, + arrayData.getShort(i)); + } + } + } + + public void writePrimitiveIntArray(ArrayData arrayData) { + int[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).intArray()) != null) { + int length = input.length * 4; + Platform.copyMemory(input, Platform.INT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putInt(holder.buffer, startingOffset + headerInBytes + i * 4, + arrayData.getInt(i)); + } + } + } + + public void writePrimitiveLongArray(ArrayData arrayData) { + long[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).longArray()) != null) { + int length = input.length * 8; + Platform.copyMemory(input, Platform.LONG_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putLong(holder.buffer, startingOffset + headerInBytes + i * 8, + arrayData.getLong(i)); + } + } + } + + public void writePrimitiveFloatArray(ArrayData arrayData) { + float[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).floatArray()) != null) { + int length = input.length * 4; + Platform.copyMemory(input, Platform.FLOAT_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putFloat(holder.buffer, startingOffset + headerInBytes + i * 4, + arrayData.getFloat(i)); + } + } + } + + public void writePrimitiveDoubleArray(ArrayData arrayData) { + double[] input; + if (arrayData instanceof GenericArrayData && + (input = ((GenericArrayData)arrayData).doubleArray()) != null) { + int length = input.length * 8; + Platform.copyMemory(input, Platform.DOUBLE_ARRAY_OFFSET, + holder.buffer, startingOffset + headerInBytes, length); + } else { + int length = arrayData.numElements(); + for (int i = 0; i < length; i++) { + Platform.putDouble(holder.buffer, startingOffset + headerInBytes + i * 8, + arrayData.getDouble(i)); + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 7e4c9089a2cb..0854990f5ef8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -117,12 +117,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, input.value, et, bufferHolder)} + ${writeArrayToBuffer(ctx, input.value, et, cn, bufferHolder)} $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ @@ -171,6 +171,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx: CodegenContext, input: String, elementType: DataType, + containsNull: Boolean, bufferHolder: String): String = { val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") @@ -202,10 +203,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, cn) => s""" final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, element, et, bufferHolder)} + ${writeArrayToBuffer(ctx, element, et, cn, bufferHolder)} $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ @@ -225,6 +226,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" + val storeElements = if (containsNull) { + s""" + for (int $index = 0; $index < $numElements; $index++) { + if ($input.isNullAt($index)) { + $arrayWriter.setNull$primitiveTypeName($index); + } else { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } + } + """ + } else { + if (ctx.isPrimitiveType(et)) { + val typeName = ctx.primitiveTypeName(et) + s"$arrayWriter.writePrimitive${typeName}Array($input);" + } else { + s""" + for (int $index = 0; $index < $numElements; $index++) { + final $jt $element = ${ctx.getValue(input, et, index)}; + $writeElement + } + """ + } + } + s""" if ($input instanceof UnsafeArrayData) { ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} @@ -232,14 +258,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $numElements = $input.numElements(); $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); - for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { - $arrayWriter.setNull$primitiveTypeName($index); - } else { - final $jt $element = ${ctx.getValue(input, et, index)}; - $writeElement - } - } + $storeElements } """ } @@ -271,11 +290,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, keys, keyType, false, bufferHolder)} // Write the numBytes of key array into the first 8 bytes. Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, values, valueType, true, bufferHolder)} } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 7ee9581b63af..34236ad6f4b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.types.{DataType, Decimal} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} private object GenericArrayData { @@ -33,36 +33,72 @@ private object GenericArrayData { } -class GenericArrayData(val array: Array[Any]) extends ArrayData { +class GenericArrayData(val array: Array[Any], + val booleanArray: Array[Boolean], val byteArray: Array[Byte], val shortArray: Array[Short], + val intArray: Array[Int], val longArray: Array[Long], val floatArray: Array[Float], + val doubleArray: Array[Double]) extends ArrayData { - def this(seq: Seq[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(seq.toArray, null, null, null, null, null, null, null) def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. - def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq) - def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Boolean]) = + this(null, primitiveArray, null, null, null, null, null, null) + def this(primitiveArray: Array[Byte]) = + this(null, null, primitiveArray, null, null, null, null, null) + def this(primitiveArray: Array[Short]) = + this(null, null, null, primitiveArray, null, null, null, null) + def this(primitiveArray: Array[Int]) = + this(null, null, null, null, primitiveArray, null, null, null) + def this(primitiveArray: Array[Long]) = + this(null, null, null, null, null, primitiveArray, null, null) + def this(primitiveArray: Array[Float]) = + this(null, null, null, null, null, null, primitiveArray, null) + def this(primitiveArray: Array[Double]) = + this(null, null, null, null, null, null, null, primitiveArray) + + def this(array: Array[Any]) = this(array, null, null, null, null, null, null, null) def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + if (booleanArray != null) new GenericArrayData(booleanArray.clone()) + else if (byteArray != null) new GenericArrayData(byteArray.clone()) + else if (shortArray != null) new GenericArrayData(shortArray.clone()) + else if (intArray != null) new GenericArrayData(intArray.clone()) + else if (longArray != null) new GenericArrayData(longArray.clone()) + else if (floatArray != null) new GenericArrayData(floatArray.clone()) + else if (doubleArray != null) new GenericArrayData(doubleArray.clone()) + else new GenericArrayData(array.clone()) + } - override def numElements(): Int = array.length + override def numElements(): Int = { + if (booleanArray != null) booleanArray.length + else if (byteArray != null) byteArray.length + else if (shortArray != null) shortArray.length + else if (intArray != null) intArray.length + else if (longArray != null) longArray.length + else if (floatArray != null) floatArray.length + else if (doubleArray != null) doubleArray.length + else array.length + } private def getAs[T](ordinal: Int) = array(ordinal).asInstanceOf[T] - override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) - override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) - override def getByte(ordinal: Int): Byte = getAs(ordinal) - override def getShort(ordinal: Int): Short = getAs(ordinal) - override def getInt(ordinal: Int): Int = getAs(ordinal) - override def getLong(ordinal: Int): Long = getAs(ordinal) - override def getFloat(ordinal: Int): Float = getAs(ordinal) - override def getDouble(ordinal: Int): Double = getAs(ordinal) + override def getBoolean(ordinal: Int): Boolean = + if (booleanArray != null) booleanArray(ordinal) else getAs(ordinal) + override def getByte(ordinal: Int): Byte = + if (byteArray != null) byteArray(ordinal) else getAs(ordinal) + override def getShort(ordinal: Int): Short = + if (shortArray != null) shortArray(ordinal) else getAs(ordinal) + override def getInt(ordinal: Int): Int = + if (intArray != null) intArray(ordinal) else getAs(ordinal) + override def getLong(ordinal: Int): Long = + if (longArray != null) longArray(ordinal) else getAs(ordinal) + override def getFloat(ordinal: Int): Float = + if (floatArray != null) floatArray(ordinal) else getAs(ordinal) + override def getDouble(ordinal: Int): Double = + if (doubleArray != null) doubleArray(ordinal) else getAs(ordinal) override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) @@ -71,7 +107,102 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { override def getArray(ordinal: Int): ArrayData = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal) - override def toString(): String = array.mkString("[", ",", "]") + override def isNullAt(ordinal: Int): Boolean = { + if (booleanArray != null || byteArray != null || shortArray != null || intArray != null || + longArray != null || floatArray != null || doubleArray != null) { + false + } else { + getAs[AnyRef](ordinal) eq null + } + } + + override def toBooleanArray(): Array[Boolean] = { + if (booleanArray != null) { + val len = booleanArray.length + val array = new Array[Boolean](len) + System.arraycopy(booleanArray, 0, array, 0, len) + array + } else { + super.toBooleanArray + } + } + + override def toByteArray(): Array[Byte] = { + if (byteArray != null) { + val len = byteArray.length + val array = new Array[Byte](len) + System.arraycopy(byteArray, 0, array, 0, len) + array + } else { + super.toByteArray + } + } + + override def toShortArray(): Array[Short] = { + if (shortArray != null) { + val len = shortArray.length + val array = new Array[Short](len) + System.arraycopy(shortArray, 0, array, 0, len) + array + } else { + super.toShortArray + } + } + + override def toIntArray(): Array[Int] = { + if (intArray != null) { + val len = intArray.length + val array = new Array[Int](len) + System.arraycopy(intArray, 0, array, 0, len) + array + } else { + super.toIntArray + } + } + + override def toLongArray(): Array[Long] = { + if (longArray != null) { + val len = longArray.length + val array = new Array[Long](len) + System.arraycopy(longArray, 0, array, 0, len) + array + } else { + super.toLongArray + } + } + + override def toFloatArray(): Array[Float] = { + if (floatArray != null) { + val len = floatArray.length + val array = new Array[Float](len) + System.arraycopy(floatArray, 0, array, 0, len) + array + } else { + super.toFloatArray + } + } + + override def toDoubleArray(): Array[Double] = { + if (doubleArray != null) { + val len = doubleArray.length + val array = new Array[Double](len) + System.arraycopy(doubleArray, 0, array, 0, len) + array + } else { + super.toDoubleArray + } + } + + override def toString(): String = { + if (booleanArray != null) booleanArray.mkString("[", ",", "]") + else if (byteArray != null) byteArray.mkString("[", ",", "]") + else if (shortArray != null) shortArray.mkString("[", ",", "]") + else if (intArray != null) intArray.mkString("[", ",", "]") + else if (longArray != null) longArray.mkString("[", ",", "]") + else if (floatArray != null) floatArray.mkString("[", ",", "]") + else if (doubleArray != null) doubleArray.mkString("[", ",", "]") + else array.mkString("[", ",", "]") + } override def equals(o: Any): Boolean = { if (!o.isInstanceOf[GenericArrayData]) { @@ -88,6 +219,26 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { return false } + if ((booleanArray != null) && (other.booleanArray != null)) { + return java.util.Arrays.equals(booleanArray, other.booleanArray) + } else if ((byteArray != null) && (other.byteArray != null)) { + return java.util.Arrays.equals(byteArray, other.byteArray) + } else if ((shortArray != null) && (other.shortArray != null)) { + return java.util.Arrays.equals(shortArray, other.shortArray) + } else if ((intArray != null) && (other.intArray != null)) { + return java.util.Arrays.equals(intArray, other.intArray) + } else if ((longArray != null) && (other.longArray != null)) { + return java.util.Arrays.equals(longArray, other.longArray) + } else if ((floatArray != null) && (other.floatArray != null)) { + return java.util.Arrays.equals(floatArray, other.floatArray) + } else if ((doubleArray != null) && (other.doubleArray != null)) { + return java.util.Arrays.equals(doubleArray, other.doubleArray) + } + + if ((array == null) || (other.array == null)) { + return false + } + var i = 0 while (i < len) { if (isNullAt(i) != other.isNullAt(i)) { @@ -121,6 +272,14 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { } override def hashCode: Int = { + if (booleanArray != null) return java.util.Arrays.hashCode(booleanArray) + else if (byteArray != null) return java.util.Arrays.hashCode(byteArray) + else if (shortArray != null) return java.util.Arrays.hashCode(shortArray) + else if (intArray != null) return java.util.Arrays.hashCode(intArray) + else if (longArray != null) return java.util.Arrays.hashCode(longArray) + else if (floatArray != null) return java.util.Arrays.hashCode(floatArray) + else if (doubleArray != null) return java.util.Arrays.hashCode(doubleArray) + var result: Int = 37 var i = 0 val len = numElements() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala index c7c386b5b838..0653302e09c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.unsafe.Platform class BufferHolderSuite extends SparkFunSuite { @@ -36,4 +38,190 @@ class BufferHolderSuite extends SparkFunSuite { } assert(e.getMessage.contains("exceeds size limitation")) } + + def performUnsafeArrayWriter(length: Int, elementSize: Int, f: (UnsafeArrayWriter) => Unit): + UnsafeArrayData = { + val unsafeRow = new UnsafeRow(1) + val unsafeArrayWriter = new UnsafeArrayWriter + val bufferHolder = new BufferHolder(unsafeRow, 32) + bufferHolder.reset() + val cursor = bufferHolder.cursor + unsafeArrayWriter.initialize(bufferHolder, length, elementSize) + // execute UnsafeArrayWriter.foo() in f() + f(unsafeArrayWriter) + + val unsafeArray = new UnsafeArrayData + unsafeArray.pointTo(bufferHolder.buffer, cursor.toLong, bufferHolder.cursor - cursor) + assert(unsafeArray.numElements() == length) + unsafeArray + } + + def initializeUnsafeArrayData(data: Seq[Any], elementSize: Int): UnsafeArrayData = { + val length = data.length + val unsafeArray = new UnsafeArrayData + val headerSize = UnsafeArrayData.calculateHeaderPortionInBytes(length) + val size = headerSize + elementSize * length + val buffer = new Array[Byte](size) + Platform.putInt(buffer, Platform.BYTE_ARRAY_OFFSET, length) + unsafeArray.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, size) + assert(unsafeArray.numElements == length) + data.zipWithIndex.map { case (e, i) => + val offset = Platform.BYTE_ARRAY_OFFSET + headerSize + elementSize * i + e match { + case _ : Boolean => Platform.putBoolean(buffer, offset, e.asInstanceOf[Boolean]) + case _ : Byte => Platform.putByte(buffer, offset, e.asInstanceOf[Byte]) + case _ : Short => Platform.putShort(buffer, offset, e.asInstanceOf[Short]) + case _ : Int => Platform.putInt(buffer, offset, e.asInstanceOf[Int]) + case _ : Long => Platform.putLong(buffer, offset, e.asInstanceOf[Long]) + case _ : Float => Platform.putFloat(buffer, offset, e.asInstanceOf[Float]) + case _ : Double => Platform.putDouble(buffer, offset, e.asInstanceOf[Double]) + case _ => throw new UnsupportedOperationException() + } + } + unsafeArray + } + + val booleanData = Seq(true, false) + val byteData = Seq(0.toByte, 1.toByte, Byte.MaxValue, Byte.MinValue) + val shortData = Seq(0.toShort, 1.toShort, Short.MaxValue, Short.MinValue) + val intData = Seq(0, 1, -1, Int.MaxValue, Int.MinValue) + val longData = Seq(0.toLong, 1.toLong, -1.toLong, Long.MaxValue, Long.MinValue) + val floatData = Seq(0.toFloat, 1.1.toFloat, -1.1.toFloat, Float.MaxValue, Float.MinValue) + val doubleData = Seq(0.toDouble, 1.1.toDouble, -1.1.toDouble, Double.MaxValue, Double.MinValue) + + test("UnsafeArrayDataWriter write") { + val boolUnsafeArray = performUnsafeArrayWriter(booleanData.length, 1, + (writer: UnsafeArrayWriter) => booleanData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + booleanData.zipWithIndex.map { case (e, i) => assert(boolUnsafeArray.getBoolean(i) == e) } + + val byteUnsafeArray = performUnsafeArrayWriter(byteData.length, 1, + (writer: UnsafeArrayWriter) => byteData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + byteData.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortUnsafeArray = performUnsafeArrayWriter(shortData.length, 2, + (writer: UnsafeArrayWriter) => shortData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + shortData.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intUnsafeArray = performUnsafeArrayWriter(intData.length, 4, + (writer: UnsafeArrayWriter) => intData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + intData.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longUnsafeArray = performUnsafeArrayWriter(longData.length, 8, + (writer: UnsafeArrayWriter) => longData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + longData.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatUnsafeArray = performUnsafeArrayWriter(floatData.length, 8, + (writer: UnsafeArrayWriter) => floatData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + floatData.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleUnsafeArray = performUnsafeArrayWriter(doubleData.length, 8, + (writer: UnsafeArrayWriter) => doubleData.zipWithIndex.map { + case (e, i) => writer.write(i, e) }) + doubleData.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } + + test("toPrimitiveArray") { + val booleanUnsafeArray = initializeUnsafeArrayData(booleanData, 1) + booleanUnsafeArray.toBooleanArray(). + zipWithIndex.map { case (e, i) => assert(e == booleanData(i)) } + + val byteUnsafeArray = initializeUnsafeArrayData(byteData, 1) + byteUnsafeArray.toByteArray().zipWithIndex.map { case (e, i) => assert(e == byteData(i)) } + + val shortUnsafeArray = initializeUnsafeArrayData(shortData, 2) + shortUnsafeArray.toShortArray().zipWithIndex.map { case (e, i) => assert(e == shortData(i)) } + + val intUnsafeArray = initializeUnsafeArrayData(intData, 4) + intUnsafeArray.toIntArray().zipWithIndex.map { case (e, i) => assert(e == intData(i)) } + + val longUnsafeArray = initializeUnsafeArrayData(longData, 8) + longUnsafeArray.toLongArray().zipWithIndex.map { case (e, i) => assert(e == longData(i)) } + + val floatUnsafeArray = initializeUnsafeArrayData(floatData, 4) + floatUnsafeArray.toFloatArray().zipWithIndex.map { case (e, i) => assert(e == floatData(i)) } + + val doubleUnsafeArray = initializeUnsafeArrayData(doubleData, 8) + doubleUnsafeArray.toDoubleArray(). + zipWithIndex.map { case (e, i) => assert(e == doubleData(i)) } + } + + test("fromPrimitiveArray") { + val booleanArray = booleanData.toArray + val booleanUnsafeArray = UnsafeArrayData.fromPrimitiveArray(booleanArray) + booleanArray.zipWithIndex.map { case (e, i) => assert(booleanUnsafeArray.getBoolean(i) == e) } + + val byteArray = byteData.toArray + val byteUnsafeArray = UnsafeArrayData.fromPrimitiveArray(byteArray) + byteArray.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortArray = shortData.toArray + val shortUnsafeArray = UnsafeArrayData.fromPrimitiveArray(shortArray) + shortArray.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intArray = intData.toArray + val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) + intArray.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longArray = longData.toArray + val longUnsafeArray = UnsafeArrayData.fromPrimitiveArray(longArray) + longArray.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatArray = floatData.toArray + val floatUnsafeArray = UnsafeArrayData.fromPrimitiveArray(floatArray) + floatArray.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleArray = doubleData.toArray + val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray) + doubleArray.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } + + test("writePrimitiveArray") { + val booleanArray = booleanData.toArray + val booleanUnsafeArray = performUnsafeArrayWriter(booleanArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveBooleanArray(new GenericArrayData(booleanArray))) + booleanArray.zipWithIndex.map { case (e, i) => assert(booleanUnsafeArray.getBoolean(i) == e) } + + val byteArray = byteData.toArray + val byteUnsafeArray = performUnsafeArrayWriter(byteArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveByteArray(new GenericArrayData(byteArray))) + byteArray.zipWithIndex.map { case (e, i) => assert(byteUnsafeArray.getByte(i) == e) } + + val shortArray = shortData.toArray + val shortUnsafeArray = performUnsafeArrayWriter(shortArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveShortArray(new GenericArrayData(shortArray))) + shortArray.zipWithIndex.map { case (e, i) => assert(shortUnsafeArray.getShort(i) == e) } + + val intArray = intData.toArray + val intUnsafeArray = performUnsafeArrayWriter(intArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveIntArray(new GenericArrayData(intArray))) + intArray.zipWithIndex.map { case (e, i) => assert(intUnsafeArray.getInt(i) == e) } + + val longArray = longData.toArray + val longUnsafeArray = performUnsafeArrayWriter(longArray.length, 8, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveLongArray(new GenericArrayData(longArray))) + longArray.zipWithIndex.map { case (e, i) => assert(longUnsafeArray.getLong(i) == e) } + + val floatArray = floatData.toArray + val floatUnsafeArray = performUnsafeArrayWriter(floatArray.length, 4, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveFloatArray(new GenericArrayData(floatArray))) + floatArray.zipWithIndex.map { case (e, i) => assert(floatUnsafeArray.getFloat(i) == e) } + + val doubleArray = doubleData.toArray + val doubleUnsafeArray = performUnsafeArrayWriter(doubleArray.length, 8, + (writer: UnsafeArrayWriter) => + writer.writePrimitiveDoubleArray(new GenericArrayData(doubleArray))) + doubleArray.zipWithIndex.map { case (e, i) => assert(doubleUnsafeArray.getDouble(i) == e) } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataSuite.scala new file mode 100644 index 000000000000..4c4c5d6b4e9f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/GenericArrayDataSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData + +class GenericArrayDataSuite extends SparkFunSuite { + + test("equals/hash") { + val booleanPrimitiveArray = Array(true, false, true) + val booleanArray1 = new GenericArrayData(booleanPrimitiveArray) + val booleanArray2 = new GenericArrayData(booleanPrimitiveArray) + val anyBooleanArray = new GenericArrayData(booleanPrimitiveArray.toArray[Any]) + assert(booleanArray1.equals(booleanArray2)) + assert(!booleanArray1.equals(anyBooleanArray)) + assert(booleanArray1.hashCode == booleanArray2.hashCode) + assert(booleanArray1.hashCode != anyBooleanArray.hashCode) + + val bytePrimitiveArray = Array(1.toByte, 10.toByte, 100.toByte) + val byteArray1 = new GenericArrayData(bytePrimitiveArray) + val byteArray2 = new GenericArrayData(bytePrimitiveArray) + val anyByteArray = new GenericArrayData(bytePrimitiveArray.toArray[Any]) + assert(byteArray1.equals(byteArray2)) + assert(!byteArray1.equals(anyByteArray)) + assert(byteArray1.hashCode == byteArray2.hashCode) + assert(byteArray1.hashCode != anyByteArray.hashCode) + + val shortPrimitiveArray = Array[Short](1.toShort, 100.toShort, 10000.toShort) + val shortArray1 = new GenericArrayData(shortPrimitiveArray) + val shortArray2 = new GenericArrayData(shortPrimitiveArray) + val anyShortArray = new GenericArrayData(shortPrimitiveArray.toArray[Any]) + assert(shortArray1.equals(shortArray2)) + assert(!shortArray1.equals(anyShortArray)) + assert(shortArray1.hashCode == shortArray2.hashCode) + assert(shortArray1.hashCode != anyShortArray.hashCode) + + val intPrimitiveArray = Array(1, 1000, 1000000) + val intArray1 = new GenericArrayData(intPrimitiveArray) + val intArray2 = new GenericArrayData(intPrimitiveArray) + val anyIntArray = new GenericArrayData(intPrimitiveArray.toArray[Any]) + assert(intArray1.equals(intArray2)) + assert(!intArray1.equals(anyIntArray)) + assert(intArray1.hashCode == intArray2.hashCode) + assert(intArray2.hashCode != anyIntArray.hashCode) + + val longPrimitiveArray = Array(1L, 100000L, 10000000000L) + val longArray1 = new GenericArrayData(longPrimitiveArray) + val longArray2 = new GenericArrayData(longPrimitiveArray) + val anyLongArray = new GenericArrayData(longPrimitiveArray.toArray[Any]) + assert(longArray1.equals(longArray2)) + assert(!longArray1.equals(anyLongArray)) + assert(longArray1.hashCode == longArray2.hashCode) + assert(longArray1.hashCode != anyLongArray.hashCode) + + val floatPrimitiveArray = Array(1.1f, 2.2f, 3.3f) + val floatArray1 = new GenericArrayData(floatPrimitiveArray) + val floatArray2 = new GenericArrayData(floatPrimitiveArray) + val anyFloatArray = new GenericArrayData(floatPrimitiveArray.toArray[Any]) + assert(floatArray1.equals(floatArray2)) + assert(!floatArray1.equals(anyFloatArray)) + assert(floatArray1.hashCode == floatArray2.hashCode) + assert(floatArray1.hashCode != anyFloatArray.hashCode) + + val doublePrimitiveArray = Array(1.1, 2.2, 3.3) + val doubleArray1 = new GenericArrayData(doublePrimitiveArray) + val doubleArray2 = new GenericArrayData(doublePrimitiveArray) + val anyDoubleArray = new GenericArrayData(doublePrimitiveArray.toArray[Any]) + assert(doubleArray1.equals(doubleArray2)) + assert(!doubleArray1.equals(anyDoubleArray)) + assert(doubleArray1.hashCode == doubleArray2.hashCode) + assert(doubleArray1.hashCode != anyDoubleArray.hashCode) + } + + test("from primitive boolean array") { + val primitiveArray = Array(true, false, true) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.booleanArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getBoolean(0) == primitiveArray(0)) + assert(array.getBoolean(1) == primitiveArray(1)) + assert(array.getBoolean(2) == primitiveArray(2)) + assert(array.toBooleanArray()(0) == primitiveArray(0)) + } + + test("from primitive byte array") { + val primitiveArray = Array(1.toByte, 10.toByte, 100.toByte) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.byteArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getByte(0) == primitiveArray(0)) + assert(array.getByte(1) == primitiveArray(1)) + assert(array.getByte(2) == primitiveArray(2)) + assert(array.toByteArray()(0) == primitiveArray(0)) + } + + test("from primitive short array") { + val primitiveArray = Array[Short](1.toShort, 100.toShort, 10000.toShort) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.shortArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getShort(0) == primitiveArray(0)) + assert(array.getShort(1) == primitiveArray(1)) + assert(array.getShort(2) == primitiveArray(2)) + assert(array.toShortArray()(0) == primitiveArray(0)) + } + + test("from primitive int array") { + val primitiveArray = Array(1, 1000, 1000000) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.intArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getInt(0) == primitiveArray(0)) + assert(array.getInt(1) == primitiveArray(1)) + assert(array.getInt(2) == primitiveArray(2)) + assert(array.toIntArray()(0) == primitiveArray(0)) + } + + test("from primitive long array") { + val primitiveArray = Array(1L, 100000L, 10000000000L) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.longArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getLong(0) == primitiveArray(0)) + assert(array.getLong(1) == primitiveArray(1)) + assert(array.getLong(2) == primitiveArray(2)) + assert(array.toLongArray()(0) == primitiveArray(0)) + } + + test("from primitive float array") { + val primitiveArray = Array(1.1f, 2.2f, 3.3f) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.array == null) + assert(array.floatArray != null) + assert(array.numElements == primitiveArray.length) + assert(array.isNullAt(0) == false) + assert(array.getFloat(0) == primitiveArray(0)) + assert(array.getFloat(1) == primitiveArray(1)) + assert(array.getFloat(2) == primitiveArray(2)) + assert(array.toFloatArray()(0) == primitiveArray(0)) + } + + test("from primitive double array") { + val primitiveArray = Array(1.1, 2.2, 3.3) + val array = new GenericArrayData(primitiveArray) + assert(array.isInstanceOf[GenericArrayData]) + assert(array.numElements == primitiveArray.length) + assert(array.array == null) + assert(array.doubleArray != null) + assert(array.isNullAt(0) == false) + assert(array.getDouble(0) == primitiveArray(0)) + assert(array.getDouble(1) == primitiveArray(1)) + assert(array.getDouble(2) == primitiveArray(2)) + assert(array.toDoubleArray()(0) == primitiveArray(0)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 1230b921aa27..7d7ded7d1cad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -27,6 +27,19 @@ import org.apache.spark.sql.test.SharedSQLContext class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { import testImplicits._ + test("primitive type and null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(v + 2, null, v + 3)") + checkAnswer(rows, Seq(Row(Array(3, null, 4)), Row(Array(4, null, 5)))) + } + + test("array with null on array") { + val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v"). + selectExpr("Array(Array(v, v + 1)," + + "null," + + "Array(v, v - 1))").collect + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/GenericArrayDataBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/GenericArrayDataBenchmark.scala new file mode 100644 index 000000000000..7303e4ea2f07 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/GenericArrayDataBenchmark.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.concurrent.duration._ + +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.util.Benchmark + +/** + * Benchmark [[GenericArrayData]] for specialized representation with primitive type + * To run this: + * 1. replace ignore(...) with test(...) + * 2. build/sbt "sql/test-only *benchmark.GenericArrayDataBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ +class GenericArrayDataBenchmark extends BenchmarkBase { + + def allocateGenericIntArray(iters: Int): Unit = { + val count = 1024 * 1024 + var array: GenericArrayData = null + + val primitiveIntArray = new Array[Int](count) + val specializedIntArray = { i: Int => + var n = 0 + while (n < iters) { + array = new GenericArrayData(primitiveIntArray) + n += 1 + } + } + val anyArray = primitiveIntArray.toArray[Any] + val genericIntArray = { i: Int => + var n = 0 + while (n < iters) { + array = new GenericArrayData(anyArray) + n += 1 + } + } + + val benchmark = new Benchmark("Allocate GenericArrayData for int", count * iters, + minNumIters = 10, minTime = 1.milliseconds) + benchmark.addCase("Generic ")(genericIntArray) + benchmark.addCase("Specialized")(specializedIntArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Allocate GenericArrayData for int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 0 / 0 46500044.3 0.0 1.0X + Specialized 0 / 0 170500162.6 0.0 3.7X + */ + } + + def allocateGenericDoubleArray(iters: Int): Unit = { + val count = 1024 * 1024 + var array: GenericArrayData = null + + val primitiveDoubleArray = new Array[Int](count) + val specializedDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + array = new GenericArrayData(primitiveDoubleArray) + n += 1 + } + } + val anyArray = primitiveDoubleArray.toArray[Any] + val genericDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + array = new GenericArrayData(anyArray) + n += 1 + } + } + + val benchmark = new Benchmark("Allocate GenericArrayData for double", count * iters, + minNumIters = 10, minTime = 1.milliseconds) + benchmark.addCase("Generic ")(genericDoubleArray) + benchmark.addCase("Specialized")(specializedDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Allocate GenericArrayData for double: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 0 / 0 55627374.0 0.0 1.0X + Specialized 0 / 0 177724745.8 0.0 3.2X + */ + } + + def getPrimitiveIntArray(iters: Int): Unit = { + val count = 1024 * 1024 * 8 + + val anyArray: GenericArrayData = new GenericArrayData(new Array[Int](count).toArray[Any]) + val intArray: GenericArrayData = new GenericArrayData(new Array[Int](count)) + var primitiveIntArray: Array[Int] = null + val genericIntArray = { i: Int => + var n = 0 + while (n < iters) { + primitiveIntArray = anyArray.toIntArray + n += 1 + } + } + val specializedIntArray = { i: Int => + var n = 0 + while (n < iters) { + primitiveIntArray = intArray.toIntArray + n += 1 + } + } + + val benchmark = new Benchmark("Get int primitive array", count * iters) + benchmark.addCase("Generic")(genericIntArray) + benchmark.addCase("Specialized")(specializedIntArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Get int primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 334 / 382 502.4 2.0 1.0X + Specialized 282 / 314 595.4 1.7 1.2X + */ + } + + def getPrimitiveDoubleArray(iters: Int): Unit = { + val count = 1024 * 1024 * 8 + + val anyArray: GenericArrayData = new GenericArrayData(new Array[Double](count).toArray[Any]) + val doubleArray: GenericArrayData = new GenericArrayData(new Array[Double](count)) + var primitiveDoubleArray: Array[Double] = null + val genericDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + primitiveDoubleArray = anyArray.toDoubleArray + n += 1 + } + } + val specializedDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + primitiveDoubleArray = doubleArray.toDoubleArray + n += 1 + } + } + + val benchmark = new Benchmark("Get double primitive array", count * iters) + benchmark.addCase("Generic")(genericDoubleArray) + benchmark.addCase("Specialized")(specializedDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Get double primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 1720 / 1883 97.6 10.3 1.0X + Specialized 703 / 1117 238.7 4.2 2.4X + */ + } + + def readGenericIntArray(iters: Int): Unit = { + val count = 1024 * 1024 * 8 + var result: Int = 0 + + val anyArray = new GenericArrayData(new Array[Int](count).toArray[Any]) + val genericIntArray = { i: Int => + var n = 0 + while (n < iters) { + val len = anyArray.numElements + var sum = 0 + var i = 0 + while (i < len) { + sum += anyArray.getInt(i) + i += 1 + } + result = sum + n += 1 + } + } + + val intArray = new GenericArrayData(new Array[Int](count)) + val specializedIntArray = { i: Int => + var n = 0 + while (n < iters) { + val len = intArray.numElements + var sum = 0 + var i = 0 + while (i < len) { + sum += intArray.getInt(i) + i += 1 + } + result = sum + n += 1 + } + } + + val benchmark = new Benchmark("Read GenericArrayData Int", count * iters) + benchmark.addCase("Generic")(genericIntArray) + benchmark.addCase("Specialized")(specializedIntArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Read GenericArrayData Int: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 206 / 212 1017.6 1.0 1.0X + Specialized 161 / 167 1301.0 0.8 1.3X + */ + } + + def readGenericDoubleArray(iters: Int): Unit = { + val count = 1024 * 1024 * 8 + var result: Double = 0 + + val anyArray = new GenericArrayData(new Array[Double](count).toArray[Any]) + val genericDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + val len = anyArray.numElements + var sum = 0.toDouble + var i = 0 + while (i < len) { + sum += anyArray.getDouble(i) + i += 1 + } + result = sum + n += 1 + } + } + + val doubleArray = new GenericArrayData(new Array[Double](count)) + val specializedDoubleArray = { i: Int => + var n = 0 + while (n < iters) { + val len = doubleArray.numElements + var sum = 0.toDouble + var i = 0 + while (i < len) { + sum += doubleArray.getDouble(i) + i += 1 + } + result = sum + n += 1 + } + } + + val benchmark = new Benchmark("Read GenericArrayData Double", count * iters) + benchmark.addCase("Generic")(genericDoubleArray) + benchmark.addCase("Specialized")(specializedDoubleArray) + benchmark.run + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Read GenericArrayData Double: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Generic 547 / 581 383.3 2.6 1.0X + Specialized 237 / 260 884.0 1.1 2.3X + */ + } + + ignore("allocate GenericArrayData") { + allocateGenericIntArray(20) + allocateGenericDoubleArray(20) + } + + ignore("get primitive array") { + getPrimitiveIntArray(20) + getPrimitiveDoubleArray(20) + } + + ignore("read elements in GenericArrayData") { + readGenericIntArray(25) + readGenericDoubleArray(25) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala index e7c8f2717fd7..5eba0de1215d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/PrimitiveArrayBenchmark.scala @@ -79,4 +79,44 @@ class PrimitiveArrayBenchmark extends BenchmarkBase { ignore("Write an array in Dataset") { writeDatasetArray(4) } + + def writeArray(iters: Int): Unit = { + import sparkSession.implicits._ + + val iters = 5 + val n = 1024 * 1024 + val rows = 15 + + val benchmark = new Benchmark("Write an array in Dataframe", n) + + val intDF = sparkSession.sparkContext.parallelize(0 until rows, 1) + .map(i => Array.tabulate(n)(i => i)).toDF() + intDF.count() // force to create df + + benchmark.addCase(s"Write int array in DataFrame", numIters = iters)(iter => { + intDF.selectExpr("value as a").collect + }) + + val doubleDF = sparkSession.sparkContext.parallelize(0 until rows, 1) + .map(i => Array.tabulate(n)(i => i.toDouble)).toDF() + doubleDF.count() // force to create df + + benchmark.addCase(s"Write double array in DataFrame", numIters = iters)(iter => { + doubleDF.selectExpr("value as a").collect + }) + + benchmark.run() + /* + OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 + Intel Xeon E3-12xx v2 (Ivy Bridge) + Read primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + Write int array in DataFrame 1290 / 1748 0.8 1230.1 1.0X + Write double array in DataFrame 1761 / 2236 0.6 1679.0 0.7X + */ + } + + ignore("Write an array in DataFrame") { + writeArray(1) + } }