diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index ad267ab0c9c4..2c05dc89ed52 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -25,6 +25,7 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.internal.SQLConf; @@ -82,6 +83,15 @@ public static final class Array extends ArrayData { public int length; public int offset; + // Buffers returned by the toXxxArray() methods; each is allocated lazily and replaced only when the array length changes + protected boolean[] reuseBooleanArray; + protected byte[] reuseByteArray; + protected short[] reuseShortArray; + protected int[] reuseIntArray; + protected long[] reuseLongArray; + protected float[] reuseFloatArray; + protected double[] reuseDoubleArray; + // Populate if binary data is required for the Array. This is stored here as an optimization // for string data. 
public byte[] byteArray; @@ -102,6 +112,69 @@ public ArrayData copy() { throw new UnsupportedOperationException(); } + @Override + public boolean[] toBooleanArray() { + if (reuseBooleanArray == null || reuseBooleanArray.length != length) { // same pattern in every toXxxArray() below: allocate only when no cached buffer of exactly this length exists + reuseBooleanArray = new boolean[length]; + } + data.getBooleanArray(offset, length, reuseBooleanArray); + return reuseBooleanArray; // cached buffer, not a fresh copy + } + + @Override + public byte[] toByteArray() { + if (reuseByteArray == null || reuseByteArray.length != length) { + reuseByteArray = new byte[length]; + } + data.getByteArray(offset, length, reuseByteArray); + return reuseByteArray; // cached buffer, not a fresh copy + } + + @Override + public short[] toShortArray() { + if (reuseShortArray == null || reuseShortArray.length != length) { + reuseShortArray = new short[length]; + } + data.getShortArray(offset, length, reuseShortArray); + return reuseShortArray; // cached buffer, not a fresh copy + } + + @Override + public int[] toIntArray() { + if (reuseIntArray == null || reuseIntArray.length != length) { + reuseIntArray = new int[length]; + } + data.getIntArray(offset, length, reuseIntArray); + return reuseIntArray; // cached buffer, not a fresh copy + } + + @Override + public long[] toLongArray() { + if (reuseLongArray == null || reuseLongArray.length != length) { + reuseLongArray = new long[length]; + } + data.getLongArray(offset, length, reuseLongArray); + return reuseLongArray; // cached buffer, not a fresh copy + } + + @Override + public float[] toFloatArray() { + if (reuseFloatArray == null || reuseFloatArray.length != length) { + reuseFloatArray = new float[length]; + } + data.getFloatArray(offset, length, reuseFloatArray); + return reuseFloatArray; // cached buffer, not a fresh copy + } + + @Override + public double[] toDoubleArray() { + if (reuseDoubleArray == null || reuseDoubleArray.length != length) { + reuseDoubleArray = new double[length]; + } + data.getDoubleArray(offset, length, reuseDoubleArray); + return reuseDoubleArray; // cached buffer, not a fresh copy + } + // TODO: this is extremely expensive. 
@Override public Object[] array() { @@ -368,6 +441,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract boolean getBoolean(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getBooleanArray(int rowId, int count, boolean[] array); + /** * Sets the value at rowId to `value`. */ @@ -388,6 +466,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract byte getByte(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getByteArray(int rowId, int count, byte[] array); + /** * Sets the value at rowId to `value`. */ @@ -403,6 +486,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getShortArray(int rowId, int count, short[] array); + /** * Returns the value for rowId. */ @@ -434,6 +522,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract int getInt(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getIntArray(int rowId, int count, int[] array); + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -467,6 +560,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract long getLong(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getLongArray(int rowId, int count, long[] array); + /** * Sets the value at rowId to `value`. 
*/ @@ -494,6 +592,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract float getFloat(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getFloatArray(int rowId, int count, float[] array); + /** * Sets the value at rowId to `value`. */ @@ -521,6 +624,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract double getDouble(int rowId); + /** + * Copies `count` values, starting at `rowId`, into the front of `array`. + */ + public abstract void getDoubleArray(int rowId, int count, double[] array); + /** * Puts a byte array that already exists in this column. */ @@ -562,6 +670,26 @@ public final Array getArray(int rowId) { return resultArray; } + public final int putArray(int rowId, ArrayData array) { // stores `array` as the value of row rowId; returns the size in bytes of the source UnsafeArrayData + UnsafeArrayData unsafeArray = (UnsafeArrayData)array; // only UnsafeArrayData is accepted; anything else fails this cast + Object baseObjects = unsafeArray.getBaseObject(); + int length = unsafeArray.getSizeInBytes(); + int numElements = unsafeArray.numElements(); + long elementOffset = unsafeArray.getBaseOffset() + UnsafeArrayData.calculateHeaderPortionInBytes(numElements); + childColumns[0].putArray(rowId, baseObjects, (int) elementOffset, elementsAppended, numElements); + putArray(rowId, elementsAppended, numElements); + + if (((ArrayType)type).containsNull()) { + for (int i = 0, dst = elementsAppended; i < numElements; i++, dst++) { // dst is the position of element i inside the child column + if (unsafeArray.isNullAt(i)) { + childColumns[0].putNull(dst); + } + } + } + elementsAppended += numElements; // advance the cursor only after the null flags were written relative to its old value + return length; + } + /** * Loads the data into array.byteArray. */ @@ -570,6 +698,7 @@ public final Array getArray(int rowId) { /** * Sets the value at rowId to `value`. 
*/ public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + /** Copies numElements values from `value`, read at byte offset srcOffset, into this column starting at element dstOffset. */ public abstract void putArray(int rowId, Object value, int srcOffset, int dstOffset, int numElements); public final int putByteArray(int rowId, byte[] value) { return putByteArray(rowId, value, 0, value.length); @@ -930,7 +1059,7 @@ public final int appendStruct(boolean isNull) { protected static final int DEFAULT_ARRAY_LENGTH = 4; /** - * Current write cursor (row index) when appending data. + * Current write cursor (row index) when appending or putting data. */ protected int elementsAppended; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index a7d3744d00e9..02cc415561cc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -134,6 +134,9 @@ public void putBooleans(int rowId, int count, boolean value) { @Override public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } + @Override + public void getBooleanArray(int offset, int length, boolean[] array) { throw new UnsupportedOperationException(); } + // // APIs dealing with Bytes // @@ -165,6 +168,9 @@ public byte getByte(int rowId) { } } + @Override + public void getByteArray(int offset, int length, byte[] array) { throw new UnsupportedOperationException(); } + // // APIs dealing with shorts // @@ -197,6 +203,9 @@ public short getShort(int rowId) { } } + @Override + public void getShortArray(int offset, int length, short[] array) { throw new UnsupportedOperationException(); } + // // APIs dealing with ints // @@ -255,6 +264,9 @@ public int getDictId(int rowId) { return Platform.getInt(null, data + 4 * rowId); } + @Override + public void getIntArray(int offset, int length, int[] array) { throw new 
UnsupportedOperationException(); } + // // APIs dealing with Longs // @@ -302,6 +314,9 @@ public long getLong(int rowId) { } } + @Override + public void getLongArray(int offset, int length, long[] array) { throw new UnsupportedOperationException(); } + // // APIs dealing with floats // @@ -348,6 +363,8 @@ public float getFloat(int rowId) { } } + @Override + public void getFloatArray(int offset, int length, float[] array) { throw new UnsupportedOperationException(); } // // APIs dealing with doubles @@ -395,6 +412,9 @@ public double getDouble(int rowId) { } } + @Override + public void getDoubleArray(int offset, int length, double[] array) { throw new UnsupportedOperationException(); } + // // APIs dealing with Arrays. // @@ -405,6 +425,11 @@ public void putArray(int rowId, int offset, int length) { Platform.putInt(null, offsetData + 4 * rowId, offset); } + @Override + public void putArray(int rowId, Object src, int srcOffset, int dstOffset, int length) { + throw new UnsupportedOperationException(); // bulk element copy is not implemented for off-heap storage + } + @Override public int getArrayLength(int rowId) { return Platform.getInt(null, lengthData + 4 * rowId); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 94ed32294cfa..4979a18beb71 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -130,6 +130,12 @@ public boolean getBoolean(int rowId) { return byteData[rowId] == 1; } + @Override + public void getBooleanArray(int offset, int length, boolean[] array) { + // assumes boolean[] elements are single bytes holding 0 or 1, so the values in byteData can be bulk-copied as-is + Platform.copyMemory(byteData, Platform.BYTE_ARRAY_OFFSET + offset, array, Platform.BOOLEAN_ARRAY_OFFSET, length); + } + // // @@ -162,6 +168,11 @@ public byte getByte(int rowId) { } } + @Override + public 
void getByteArray(int offset, int length, byte[] array) { + Platform.copyMemory(byteData, Platform.BYTE_ARRAY_OFFSET + offset, array, Platform.BYTE_ARRAY_OFFSET, length); // copies into array from index 0 + } + // // APIs dealing with Shorts // @@ -192,6 +203,10 @@ public short getShort(int rowId) { } } + @Override + public void getShortArray(int offset, int length, short[] array) { + Platform.copyMemory(shortData, Platform.SHORT_ARRAY_OFFSET + offset * 2, array, Platform.SHORT_ARRAY_OFFSET, length * 2); // offsets and sizes are in bytes, 2 per short + } // // APIs dealing with Ints @@ -234,6 +249,12 @@ public int getInt(int rowId) { } } + @Override + public void getIntArray(int offset, int length, int[] array) { + assert(dictionary == null); // NOTE(review): presumably because the raw copy below reads intData directly and cannot decode dictionary ids; confirm + Platform.copyMemory(intData, Platform.INT_ARRAY_OFFSET + offset * 4, array, Platform.INT_ARRAY_OFFSET, length * 4); // 4 bytes per int + } + /** * Returns the dictionary Id for rowId. * This should only be called when the ColumnVector is dictionaryIds. @@ -286,6 +307,12 @@ public long getLong(int rowId) { } } + @Override + public void getLongArray(int offset, int length, long[] array) { + assert(dictionary == null); + Platform.copyMemory(longData, Platform.LONG_ARRAY_OFFSET + offset * 8, array, Platform.LONG_ARRAY_OFFSET, length * 8); // 8 bytes per long + } + // // APIs dealing with floats // @@ -325,6 +352,12 @@ public float getFloat(int rowId) { } } + @Override + public void getFloatArray(int offset, int length, float[] array) { + assert(dictionary == null); + Platform.copyMemory(floatData, Platform.FLOAT_ARRAY_OFFSET + offset * 4, array, Platform.FLOAT_ARRAY_OFFSET, length * 4); // 4 bytes per float + } + // // APIs dealing with doubles // @@ -366,6 +399,12 @@ public double getDouble(int rowId) { } } + @Override + public void getDoubleArray(int offset, int length, double[] array) { + assert(dictionary == null); + Platform.copyMemory(doubleData, Platform.DOUBLE_ARRAY_OFFSET + offset * 8, array, Platform.DOUBLE_ARRAY_OFFSET, length * 8); // 8 bytes per double + } + // // APIs dealing with Arrays // @@ -385,6 +424,35 @@ public void putArray(int rowId, int offset, int length) { 
arrayLengths[rowId] = length; } + @Override + public void putArray(int rowId, Object src, int srcOffset, int dstOffset, int numElements) { // bulk-copies numElements values from src, read at byte offset srcOffset, into this column starting at element dstOffset + DataType et = type; + reserve(dstOffset + numElements); + if (et == DataTypes.BooleanType || et == DataTypes.ByteType) { + Platform.copyMemory( + src, srcOffset, byteData, Platform.BYTE_ARRAY_OFFSET + dstOffset, numElements); + } else if (et == DataTypes.ShortType) { // 2-byte elements are stored in shortData + Platform.copyMemory( + src, srcOffset, shortData, Platform.SHORT_ARRAY_OFFSET + dstOffset * 2, numElements * 2); + } else if (et == DataTypes.IntegerType || et == DataTypes.DateType || + DecimalType.is32BitDecimalType(type)) { + Platform.copyMemory( + src, srcOffset, intData, Platform.INT_ARRAY_OFFSET + dstOffset * 4, numElements * 4); + } else if (type instanceof LongType || type instanceof TimestampType || + DecimalType.is64BitDecimalType(type)) { + Platform.copyMemory( + src, srcOffset, longData, Platform.LONG_ARRAY_OFFSET + dstOffset * 8, numElements * 8); + } else if (et == DataTypes.FloatType) { + Platform.copyMemory( + src, srcOffset, floatData, Platform.FLOAT_ARRAY_OFFSET + dstOffset * 4, numElements * 4); + } else if (et == DataTypes.DoubleType) { + Platform.copyMemory( + src, srcOffset, doubleData, Platform.DOUBLE_ARRAY_OFFSET + dstOffset * 8, numElements * 8); + } else { + throw new RuntimeException("Unhandled " + type); + } + } + @Override public void loadBytes(ColumnVector.Array array) { array.byteArray = byteData; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e48e3f640290..963352920777 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.SparkFunSuite import 
org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -672,26 +673,30 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(2, 2, 0) column.putArray(3, 3, 3) - val a1 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - val a2 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(1)).asInstanceOf[Array[Int]] - val a3 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(2)).asInstanceOf[Array[Int]] - val a4 = ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(3)).asInstanceOf[Array[Int]] + val a1 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a2 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(1).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a3 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(2).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] + val a4 = ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(3).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] assert(a1 === Array(0)) assert(a2 === Array(1, 2)) assert(a3 === Array.empty[Int]) assert(a4 === Array(3, 4, 5)) // Verify the ArrayData APIs - assert(column.getArray(0).length == 1) + assert(column.getArray(0).numElements() == 1) assert(column.getArray(0).getInt(0) == 0) - assert(column.getArray(1).length == 2) + assert(column.getArray(1).numElements() == 2) assert(column.getArray(1).getInt(0) == 1) assert(column.getArray(1).getInt(1) == 2) - assert(column.getArray(2).length == 0) + assert(column.getArray(2).numElements() == 0) - assert(column.getArray(3).length == 3) + assert(column.getArray(3).numElements() == 3) 
assert(column.getArray(3).getInt(0) == 3) assert(column.getArray(3).getInt(1) == 4) assert(column.getArray(3).getInt(2) == 5) @@ -704,8 +709,55 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(data.capacity == array.length * 2) data.putInts(0, array.length, array, 0) column.putArray(0, 0, array.length) - assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] - === array) + assert(ColumnVectorUtils.toPrimitiveJavaArray( + column.getArray(0).asInstanceOf[ColumnVector.Array]).asInstanceOf[Array[Int]] === array) + }} + } + + test("Int UnsafeArray") { + (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + + // Populate it with arrays [0], [1, 2], [], [3, 4, 5] + val len1 = column.putArray(0, UnsafeArrayData.fromPrimitiveArray(Array(0))) + val len2 = column.putArray(1, UnsafeArrayData.fromPrimitiveArray(Array(1, 2))) + val len3 = column.putArray(2, UnsafeArrayData.fromPrimitiveArray(Array.empty[Int])) + val len4 = column.putArray(3, UnsafeArrayData.fromPrimitiveArray(Array(3, 4, 5))) + // UnsafeArrayData.fromPrimitiveArray backs the data with a long[], so the size in bytes is rounded up to a multiple of 8 + assert(len1 == ((UnsafeArrayData.calculateHeaderPortionInBytes(1) + 1 * 4 + 7) / 8) * 8) + assert(len2 == ((UnsafeArrayData.calculateHeaderPortionInBytes(2) + 2 * 4 + 7) / 8) * 8) + assert(len3 == ((UnsafeArrayData.calculateHeaderPortionInBytes(0) + 0 * 4 + 7) / 8) * 8) + assert(len4 == ((UnsafeArrayData.calculateHeaderPortionInBytes(3) + 3 * 4 + 7) / 8) * 8) + + val a1 = column.getArray(0).toIntArray + val a2 = column.getArray(1).toIntArray + val a3 = column.getArray(2).toIntArray + val a4 = column.getArray(3).toIntArray + assert(a1 === Array(0)) + assert(a2 === Array(1, 2)) + assert(a3 === Array.empty[Int]) + assert(a4 === Array(3, 4, 5)) + + // Verify the ArrayData APIs + assert(column.getArray(0).numElements() == 1) + assert(column.getArray(0).getInt(0) == 0) + 
assert(column.getArray(1).numElements() == 2) + assert(column.getArray(1).getInt(0) == 1) + assert(column.getArray(1).getInt(1) == 2) + + assert(column.getArray(2).numElements() == 0) + + assert(column.getArray(3).numElements() == 3) + assert(column.getArray(3).getInt(0) == 3) + assert(column.getArray(3).getInt(1) == 4) + assert(column.getArray(3).getInt(2) == 5) + + // Put an array of 20 elements, more than the capacity of 10 the column was allocated with + column.reset + val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20) + column.putArray(0, UnsafeArrayData.fromPrimitiveArray(array)) + assert(column.getArray(0).toIntArray === array) }} }