diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a014e2aa34820..6d159a6c9ed4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -450,14 +450,13 @@ class CodegenContext {
   /**
    * Returns the specialized code to set a given value in a column vector for a given `DataType`.
    */
-  def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
-      value: String): String = {
+  def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
     val jt = javaType(dataType)
     dataType match {
       case _ if isPrimitiveType(jt) =>
-        s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
-      case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
-      case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
+        s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
+      case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
+      case StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
       case _ =>
         throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
     }
@@ -468,37 +467,36 @@ class CodegenContext {
    * that could potentially be nullable.
    */
   def updateColumn(
-      batch: String,
-      row: String,
+      vector: String,
+      rowId: String,
       dataType: DataType,
-      ordinal: Int,
       ev: ExprCode,
       nullable: Boolean): String = {
     if (nullable) {
       s"""
         if (!${ev.isNull}) {
-          ${setValue(batch, row, dataType, ordinal, ev.value)}
+          ${setValue(vector, rowId, dataType, ev.value)}
         } else {
-          $batch.column($ordinal).putNull($row);
+          $vector.putNull($rowId);
        }
       """
     } else {
-      s"""${setValue(batch, row, dataType, ordinal, ev.value)};"""
+      setValue(vector, rowId, dataType, ev.value)
     }
   }
 
   /**
    * Returns the specialized code to access a value from a column vector for a given `DataType`.
*/ - def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = { + def getValue(vector: String, rowId: String, dataType: DataType): String = { val jt = javaType(dataType) dataType match { case _ if isPrimitiveType(jt) => - s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)" + s"$vector.get${primitiveTypeName(jt)}($rowId)" case t: DecimalType => - s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})" + s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})" case StringType => - s"$batch.column($ordinal).getUTF8String($row)" + s"$vector.getUTF8String($rowId)" case _ => throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType") } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index fd8db1727212f..f37864a0f5393 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -135,9 +136,9 @@ private boolean next() throws IOException { /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, ColumnVector column) throws IOException { + void readBatch(int total, WritableColumnVector column) throws IOException { int rowId = 0; - ColumnVector dictionaryIds = null; + WritableColumnVector dictionaryIds = null; if (dictionary != null) { // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded @@ -219,8 +220,11 @@ void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column, - ColumnVector dictionaryIds) { + private void decodeDictionaryIds( + int rowId, + int num, + WritableColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: if (column.dataType() == DataTypes.IntegerType || @@ -346,13 +350,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column, * is guaranteed that num is smaller than the number of values left in the current page. */ - private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readBooleanBatch(int rowId, int num, WritableColumnVector column) throws IOException { assert(column.dataType() == DataTypes.BooleanType); defColumn.readBooleans( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } - private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readIntBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || @@ -370,7 +374,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce } } - private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readLongBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. if (column.dataType() == DataTypes.LongType || DecimalType.is64BitDecimalType(column.dataType())) { @@ -389,7 +393,7 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc } } - private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readFloatBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: support implicit cast to double? if (column.dataType() == DataTypes.FloatType) { @@ -400,7 +404,7 @@ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOEx } } - private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readDoubleBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions if (column.dataType() == DataTypes.DoubleType) { @@ -411,7 +415,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE } } - private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException { + private void readBinaryBatch(int rowId, int num, WritableColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; @@ -432,8 +436,11 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE } } - private void readFixedLenByteArrayBatch(int rowId, int num, - ColumnVector column, int arrayLen) throws IOException { + private void readFixedLenByteArrayBatch( + int rowId, + int num, + WritableColumnVector column, + int arrayLen) throws IOException { VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. 
// TODO: implement remaining type conversions diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 04f8141d66e9d..0cacf0c9c93a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,6 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -90,6 +93,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private ColumnarBatch columnarBatch; + private WritableColumnVector[] columnVectors; + /** * If true, this class returns batches instead of rows. */ @@ -172,20 +177,26 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } } - columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); + int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE; + if (memMode == MemoryMode.OFF_HEAP) { + columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); + } else { + columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); + } + columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { - ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i); - columnarBatch.column(i + partitionIdx).setIsConstant(); + ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i); + columnVectors[i + partitionIdx].setIsConstant(); } } // Initialize missing columns with nulls. 
for (int i = 0; i < missingColumns.length; i++) { if (missingColumns[i]) { - columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); - columnarBatch.column(i).setIsConstant(); + columnVectors[i].putNulls(0, columnarBatch.capacity()); + columnVectors[i].setIsConstant(); } } } @@ -226,7 +237,7 @@ public boolean nextBatch() throws IOException { int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); for (int i = 0; i < columnReaders.length; ++i) { if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnarBatch.column(i)); + columnReaders[i].readBatch(num, columnVectors[i]); } rowsReturned += num; columnarBatch.setNumRows(num); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 98018b7f48bd8..5b75f719339fb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -20,7 +20,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.unsafe.Platform; import org.apache.parquet.column.values.ValuesReader; @@ -56,7 +56,7 @@ public void skip() { } @Override - public final void readBooleans(int total, ColumnVector c, int rowId) { + public final void readBooleans(int total, WritableColumnVector c, int rowId) { // TODO: properly vectorize this for (int i = 0; i < total; i++) { c.putBoolean(rowId + i, readBoolean()); @@ -64,31 +64,31 @@ public final void readBooleans(int total, ColumnVector c, int rowId) { } @Override - public final void readIntegers(int total, ColumnVector c, int rowId) { + public final void readIntegers(int total, WritableColumnVector c, int rowId) { c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public final void readLongs(int total, ColumnVector c, int rowId) { + public final void readLongs(int total, WritableColumnVector c, int rowId) { c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 8 * total; } @Override - public final void readFloats(int total, ColumnVector c, int rowId) { + public final void readFloats(int total, WritableColumnVector c, int rowId) { c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 4 * total; } @Override - public final void readDoubles(int total, ColumnVector c, int rowId) { + public final void readDoubles(int total, WritableColumnVector c, int rowId) { c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET); offset += 8 * total; } @Override - public final void readBytes(int total, ColumnVector c, int rowId) { + public final void readBytes(int total, WritableColumnVector c, int rowId) { for (int i = 0; i < total; i++) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. // TODO: consider pushing this in ColumnVector by adding a readBytes with a stride. 
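The mechanical signature changes above encode a deliberate split: decoders need the put APIs, so they now take `WritableColumnVector`, while consumers keep the narrower read-only `ColumnVector`. A minimal sketch of that division, outside this patch (the `decodeAll` and `sum` helper names are hypothetical, and a public `OnHeapColumnVector` constructor is assumed; the put/get calls are the ones used throughout this diff):

```java
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;

class ReadWriteSplitSketch {
  // Decoders write through the mutable subtype; the put APIs live only there.
  static void decodeAll(int total, WritableColumnVector c, int rowId) {
    for (int i = 0; i < total; i++) {
      c.putInt(rowId + i, i * 2);
    }
  }

  // Consumers can be handed the same vector as a read-only ColumnVector.
  static long sum(ColumnVector c, int numRows) {
    long s = 0;
    for (int i = 0; i < numRows; i++) {
      s += c.getInt(i);
    }
    return s;
  }

  public static void main(String[] args) {
    WritableColumnVector col = new OnHeapColumnVector(1024, DataTypes.IntegerType);
    decodeAll(1024, col, 0);
    System.out.println(sum(col, 1024)); // 0 + 2 + ... + 2046
    col.close();
  }
}
```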
@@ -159,7 +159,7 @@ public final double readDouble() { } @Override - public final void readBinary(int total, ColumnVector v, int rowId) { + public final void readBinary(int total, WritableColumnVector v, int rowId) { for (int i = 0; i < total; i++) { int len = readInteger(); int start = offset; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 62157389013bb..fc7fa70c39419 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,7 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; /** * A values reader for Parquet's run-length encoded data. This is based off of the version in @@ -177,7 +177,11 @@ public int readInteger() { * c[rowId] = null; * } */ - public void readIntegers(int total, ColumnVector c, int rowId, int level, + public void readIntegers( + int total, + WritableColumnVector c, + int rowId, + int level, VectorizedValuesReader data) { int left = total; while (left > 0) { @@ -208,8 +212,12 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level, } // TODO: can this code duplication be removed without a perf penalty? - public void readBooleans(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readBooleans( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -238,8 +246,12 @@ public void readBooleans(int total, ColumnVector c, } } - public void readBytes(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readBytes( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -268,8 +280,12 @@ public void readBytes(int total, ColumnVector c, } } - public void readShorts(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { + public void readShorts( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -300,8 +316,12 @@ public void readShorts(int total, ColumnVector c, } } - public void readLongs(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readLongs( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -330,8 +350,12 @@ public void readLongs(int total, ColumnVector c, int rowId, int level, } } - public void readFloats(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readFloats( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -360,8 +384,12 @@ public void 
readFloats(int total, ColumnVector c, int rowId, int level, } } - public void readDoubles(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readDoubles( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -390,8 +418,12 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level, } } - public void readBinarys(int total, ColumnVector c, int rowId, int level, - VectorizedValuesReader data) { + public void readBinarys( + int total, + WritableColumnVector c, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -424,8 +456,13 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level, * Decoding for dictionary ids. The IDs are populated into `values` and the nullability is * populated into `nulls`. */ - public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level, - VectorizedValuesReader data) { + public void readIntegers( + int total, + WritableColumnVector values, + WritableColumnVector nulls, + int rowId, + int level, + VectorizedValuesReader data) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -459,7 +496,7 @@ public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int // IDs. This is different than the above APIs that decodes definitions levels along with values. // Since this is only used to decode dictionary IDs, only decoding integers is supported. @Override - public void readIntegers(int total, ColumnVector c, int rowId) { + public void readIntegers(int total, WritableColumnVector c, int rowId) { int left = total; while (left > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -485,32 +522,32 @@ public byte readByte() { } @Override - public void readBytes(int total, ColumnVector c, int rowId) { + public void readBytes(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readLongs(int total, ColumnVector c, int rowId) { + public void readLongs(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readBinary(int total, ColumnVector c, int rowId) { + public void readBinary(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readBooleans(int total, ColumnVector c, int rowId) { + public void readBooleans(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readFloats(int total, ColumnVector c, int rowId) { + public void readFloats(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } @Override - public void readDoubles(int total, ColumnVector c, int rowId) { + public void readDoubles(int total, WritableColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index 88418ca53fe1e..57d92ae27ece8 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.parquet; -import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.parquet.io.api.Binary; @@ -37,11 +37,11 @@ public interface VectorizedValuesReader { /* * Reads `total` values into `c` start at `c[rowId]` */ - void readBooleans(int total, ColumnVector c, int rowId); - void readBytes(int total, ColumnVector c, int rowId); - void readIntegers(int total, ColumnVector c, int rowId); - void readLongs(int total, ColumnVector c, int rowId); - void readFloats(int total, ColumnVector c, int rowId); - void readDoubles(int total, ColumnVector c, int rowId); - void readBinary(int total, ColumnVector c, int rowId); + void readBooleans(int total, WritableColumnVector c, int rowId); + void readBytes(int total, WritableColumnVector c, int rowId); + void readIntegers(int total, WritableColumnVector c, int rowId); + void readLongs(int total, WritableColumnVector c, int rowId); + void readFloats(int total, WritableColumnVector c, int rowId); + void readDoubles(int total, WritableColumnVector c, int rowId); + void readBinary(int total, WritableColumnVector c, int rowId); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index 25a565d32638d..1c94f706dc685 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -41,6 +41,7 @@ */ public class AggregateHashMap { + private OnHeapColumnVector[] columnVectors; private ColumnarBatch batch; private int[] buckets; private int numBuckets; @@ -62,7 +63,8 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int this.maxSteps = maxSteps; numBuckets = (int) (capacity / loadFactor); - batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity); + columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema); + batch = new ColumnarBatch(schema, columnVectors, capacity); buckets = new int[numBuckets]; Arrays.fill(buckets, -1); } @@ -74,8 +76,8 @@ public AggregateHashMap(StructType schema) { public ColumnarBatch.Row findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { - batch.column(0).putLong(numRows, key); - batch.column(1).putLong(numRows, 0); + columnVectors[0].putLong(numRows, key); + columnVectors[1].putLong(numRows, 0); buckets[idx] = numRows++; } return batch.getRow(buckets[idx]); @@ -105,6 +107,6 @@ private long hash(long key) { } private boolean equals(int idx, long key1) { - return batch.column(0).getLong(buckets[idx]) == key1; + return columnVectors[0].getLong(buckets[idx]) == key1; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 59d66c599c518..be2a9c246747c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -29,12 +29,13 @@ /** * A column vector 
backed by Apache Arrow. */ -public final class ArrowColumnVector extends ReadOnlyColumnVector { +public final class ArrowColumnVector extends ColumnVector { private final ArrowVectorAccessor accessor; - private final int valueCount; + private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { + int valueCount = accessor.getValueCount(); if (index < 0 || index >= valueCount) { throw new IndexOutOfBoundsException( String.format("index: %d, valueCount: %d", index, valueCount)); @@ -42,12 +43,23 @@ private void ensureAccessible(int index) { } private void ensureAccessible(int index, int count) { + int valueCount = accessor.getValueCount(); if (index < 0 || index + count > valueCount) { throw new IndexOutOfBoundsException( String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); } } + @Override + public int numNulls() { + return accessor.getNullCount(); + } + + @Override + public boolean anyNullsSet() { + return numNulls() > 0; + } + @Override public long nullsNativeAddress() { throw new RuntimeException("Cannot get native address for arrow column"); @@ -274,9 +286,20 @@ public byte[] getBinary(int rowId) { return accessor.getBinary(rowId); } + /** + * Returns the data for the underlying array. + */ + @Override + public ArrowColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + @Override + public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + public ArrowColumnVector(ValueVector vector) { - super(vector.getValueCapacity(), ArrowUtils.fromArrowField(vector.getField()), - MemoryMode.OFF_HEAP); + super(ArrowUtils.fromArrowField(vector.getField())); if (vector instanceof NullableBitVector) { accessor = new BooleanAccessor((NullableBitVector) vector); @@ -302,7 +325,7 @@ public ArrowColumnVector(ValueVector vector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); - childColumns = new ColumnVector[1]; + childColumns = new ArrowColumnVector[1]; childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); resultArray = new ColumnVector.Array(childColumns[0]); } else if (vector instanceof MapVector) { @@ -317,9 +340,6 @@ public ArrowColumnVector(ValueVector vector) { } else { throw new UnsupportedOperationException(); } - valueCount = accessor.getValueCount(); - numNulls = accessor.getNullCount(); - anyNullsSet = numNulls > 0; } private abstract static class ArrowVectorAccessor { @@ -327,14 +347,9 @@ private abstract static class ArrowVectorAccessor { private final ValueVector vector; private final ValueVector.Accessor nulls; - private final int valueCount; - private final int nullCount; - ArrowVectorAccessor(ValueVector vector) { this.vector = vector; this.nulls = vector.getAccessor(); - this.valueCount = nulls.getValueCount(); - this.nullCount = nulls.getNullCount(); } final boolean isNullAt(int rowId) { @@ -342,11 +357,11 @@ final boolean isNullAt(int rowId) { } final int getValueCount() { - return valueCount; + return nulls.getValueCount(); } final int getNullCount() { - return nullCount; + return nulls.getNullCount(); } final void close() { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 77966382881b8..a69dd9718fe33 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,23 +16,16 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; - -import com.google.common.annotations.VisibleForTesting; - -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; /** * This class represents a column of values and provides the main APIs to access the data - * values. It supports all the types and contains get/put APIs as well as their batched versions. + * values. It supports all the types and contains get APIs as well as their batched versions. * The batched versions are preferable whenever possible. * * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these @@ -40,34 +33,15 @@ * contains nullability, and in the case of Arrays, the lengths and offsets into the child column. * Lengths and offsets are encoded identically to INTs. * Maps are just a special case of a two field struct. - * Strings are handled as an Array of ByteType. - * - * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the - * responsibility of the caller to call reserve() to ensure there is enough room before adding - * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas), - * the lengths are known up front. * * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values * in the current RowBatch. * - * A ColumnVector should be considered immutable once originally created. In other words, it is not - * valid to call put APIs after reads until reset() is called. + * A ColumnVector should be considered immutable once originally created. * * ColumnVectors are intended to be reused. */ public abstract class ColumnVector implements AutoCloseable { - /** - * Allocates a column to store elements of `type` on or off heap. - * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is - * in number of elements, not number of bytes. - */ - public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) { - if (mode == MemoryMode.OFF_HEAP) { - return new OffHeapColumnVector(capacity, type); - } else { - return new OnHeapColumnVector(capacity, type); - } - } /** * Holder object to return an array. This object is intended to be reused. Callers should @@ -278,75 +252,22 @@ public Object get(int ordinal, DataType dataType) { */ public final DataType dataType() { return type; } - /** - * Resets this column for writing. The currently stored values are no longer accessible. - */ - public void reset() { - if (isConstant) return; - - if (childColumns != null) { - for (ColumnVector c: childColumns) { - c.reset(); - } - } - numNulls = 0; - elementsAppended = 0; - if (anyNullsSet) { - putNotNulls(0, capacity); - anyNullsSet = false; - } - } - /** * Cleans up memory for this column. The column is not usable after this. * TODO: this should probably have ref-counted semantics. 
*/ public abstract void close(); - public void reserve(int requiredCapacity) { - if (requiredCapacity > capacity) { - int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); - if (requiredCapacity <= newCapacity) { - try { - reserveInternal(newCapacity); - } catch (OutOfMemoryError outOfMemoryError) { - throwUnsupportedException(requiredCapacity, outOfMemoryError); - } - } else { - throwUnsupportedException(requiredCapacity, null); - } - } - } - - private void throwUnsupportedException(int requiredCapacity, Throwable cause) { - String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + - "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + - "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + - " to false."; - - if (cause != null) { - throw new RuntimeException(message, cause); - } else { - throw new RuntimeException(message); - } - } - - /** - * Ensures that there is enough storage to store capacity elements. That is, the put() APIs - * must work for all rowIds < capacity. - */ - protected abstract void reserveInternal(int capacity); - /** * Returns the number of nulls in this column. */ - public final int numNulls() { return numNulls; } + public abstract int numNulls(); /** * Returns true if any of the nulls indicator are set for this column. This can be used * as an optimization to prevent setting nulls. */ - public final boolean anyNullsSet() { return anyNullsSet; } + public abstract boolean anyNullsSet(); /** * Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid @@ -355,33 +276,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract long nullsNativeAddress(); public abstract long valuesNativeAddress(); - /** - * Sets the value at rowId to null/not null. - */ - public abstract void putNotNull(int rowId); - public abstract void putNull(int rowId); - - /** - * Sets the values from [rowId, rowId + count) to null/not null. - */ - public abstract void putNulls(int rowId, int count); - public abstract void putNotNulls(int rowId, int count); - /** * Returns whether the value at rowId is NULL. */ public abstract boolean isNullAt(int rowId); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putBoolean(int rowId, boolean value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putBooleans(int rowId, int count, boolean value); - /** * Returns the value for rowId. */ @@ -392,21 +291,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract boolean[] getBooleans(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putByte(int rowId, byte value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putBytes(int rowId, int count, byte value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -417,21 +301,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract byte[] getBytes(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putShort(int rowId, short value); - - /** - * Sets values from [rowId, rowId + count) to value. 
- */ - public abstract void putShorts(int rowId, int count, short value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -442,27 +311,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract short[] getShorts(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putInt(int rowId, int value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putInts(int rowId, int count, int value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putInts(int rowId, int count, int[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be 4-byte little endian ints. - */ - public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -480,27 +328,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract int getDictId(int rowId); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putLong(int rowId, long value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putLongs(int rowId, int count, long value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be 8-byte little endian longs. - */ - public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -511,27 +338,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract long[] getLongs(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putFloat(int rowId, float value); - - /** - * Sets values from [rowId, rowId + count) to value. - */ - public abstract void putFloats(int rowId, int count, float value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. - */ - public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -542,27 +348,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract float[] getFloats(int rowId, int count); - /** - * Sets the value at rowId to `value`. - */ - public abstract void putDouble(int rowId, double value); - - /** - * Sets values from [rowId, rowId + count) to value. 
- */ - public abstract void putDoubles(int rowId, int count, double value); - - /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) - */ - public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); - - /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. - */ - public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); - /** * Returns the value for rowId. */ @@ -573,11 +358,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { */ public abstract double[] getDoubles(int rowId, int count); - /** - * Puts a byte array that already exists in this column. - */ - public abstract void putArray(int rowId, int offset, int length); - /** * Returns the length of the array at rowid. */ @@ -608,7 +388,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) { /** * Returns the array at rowid. */ - public final Array getArray(int rowId) { + public final ColumnVector.Array getArray(int rowId) { resultArray.length = getArrayLength(rowId); resultArray.offset = getArrayOffset(rowId); return resultArray; @@ -617,24 +397,7 @@ public final Array getArray(int rowId) { /** * Loads the data into array.byteArray. */ - public abstract void loadBytes(Array array); - - /** - * Sets the value at rowId to `value`. - */ - public abstract int putByteArray(int rowId, byte[] value, int offset, int count); - public final int putByteArray(int rowId, byte[] value) { - return putByteArray(rowId, value, 0, value.length); - } - - /** - * Returns the value for rowId. - */ - private Array getByteArray(int rowId) { - Array array = getArray(rowId); - array.data.loadBytes(array); - return array; - } + public abstract void loadBytes(ColumnVector.Array array); /** * Returns the value for rowId. @@ -646,354 +409,42 @@ public MapData getMap(int ordinal) { /** * Returns the decimal for rowId. */ - public Decimal getDecimal(int rowId, int precision, int scale) { - if (precision <= Decimal.MAX_INT_DIGITS()) { - return Decimal.createUnsafe(getInt(rowId), precision, scale); - } else if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.createUnsafe(getLong(rowId), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(rowId); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } - } - - - public void putDecimal(int rowId, Decimal value, int precision) { - if (precision <= Decimal.MAX_INT_DIGITS()) { - putInt(rowId, (int) value.toUnscaledLong()); - } else if (precision <= Decimal.MAX_LONG_DIGITS()) { - putLong(rowId, value.toUnscaledLong()); - } else { - BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); - putByteArray(rowId, bigInteger.toByteArray()); - } - } + public abstract Decimal getDecimal(int rowId, int precision, int scale); /** * Returns the UTF8String for rowId. */ - public UTF8String getUTF8String(int rowId) { - if (dictionary == null) { - ColumnVector.Array a = getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); - } else { - byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - return UTF8String.fromBytes(bytes); - } - } + public abstract UTF8String getUTF8String(int rowId); /** * Returns the byte array for rowId. 
*/ - public byte[] getBinary(int rowId) { - if (dictionary == null) { - ColumnVector.Array array = getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; - } else { - return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); - } - } - - /** - * Append APIs. These APIs all behave similarly and will append data to the current vector. It - * is not valid to mix the put and append APIs. The append APIs are slower and should only be - * used if the sizes are not known up front. - * In all these cases, the return value is the rowId for the first appended element. - */ - public final int appendNull() { - assert (!(dataType() instanceof StructType)); // Use appendStruct() - reserve(elementsAppended + 1); - putNull(elementsAppended); - return elementsAppended++; - } - - public final int appendNotNull() { - reserve(elementsAppended + 1); - putNotNull(elementsAppended); - return elementsAppended++; - } - - public final int appendNulls(int count) { - assert (!(dataType() instanceof StructType)); - reserve(elementsAppended + count); - int result = elementsAppended; - putNulls(elementsAppended, count); - elementsAppended += count; - return result; - } - - public final int appendNotNulls(int count) { - assert (!(dataType() instanceof StructType)); - reserve(elementsAppended + count); - int result = elementsAppended; - putNotNulls(elementsAppended, count); - elementsAppended += count; - return result; - } - - public final int appendBoolean(boolean v) { - reserve(elementsAppended + 1); - putBoolean(elementsAppended, v); - return elementsAppended++; - } - - public final int appendBooleans(int count, boolean v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putBooleans(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendByte(byte v) { - reserve(elementsAppended + 1); - putByte(elementsAppended, v); - return elementsAppended++; - } - - public final int appendBytes(int count, byte v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putBytes(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendBytes(int length, byte[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putBytes(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendShort(short v) { - reserve(elementsAppended + 1); - putShort(elementsAppended, v); - return elementsAppended++; - } - - public final int appendShorts(int count, short v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putShorts(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendShorts(int length, short[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putShorts(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendInt(int v) { - reserve(elementsAppended + 1); - putInt(elementsAppended, v); - return elementsAppended++; - } - - public final int appendInts(int count, int v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putInts(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendInts(int length, int[] src, int offset) { - reserve(elementsAppended + 
length); - int result = elementsAppended; - putInts(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendLong(long v) { - reserve(elementsAppended + 1); - putLong(elementsAppended, v); - return elementsAppended++; - } - - public final int appendLongs(int count, long v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putLongs(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendLongs(int length, long[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putLongs(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendFloat(float v) { - reserve(elementsAppended + 1); - putFloat(elementsAppended, v); - return elementsAppended++; - } - - public final int appendFloats(int count, float v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putFloats(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendFloats(int length, float[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putFloats(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendDouble(double v) { - reserve(elementsAppended + 1); - putDouble(elementsAppended, v); - return elementsAppended++; - } - - public final int appendDoubles(int count, double v) { - reserve(elementsAppended + count); - int result = elementsAppended; - putDoubles(elementsAppended, count, v); - elementsAppended += count; - return result; - } - - public final int appendDoubles(int length, double[] src, int offset) { - reserve(elementsAppended + length); - int result = elementsAppended; - putDoubles(elementsAppended, length, src, offset); - elementsAppended += length; - return result; - } - - public final int appendByteArray(byte[] value, int offset, int length) { - int copiedOffset = arrayData().appendBytes(length, value, offset); - reserve(elementsAppended + 1); - putArray(elementsAppended, copiedOffset, length); - return elementsAppended++; - } - - public final int appendArray(int length) { - reserve(elementsAppended + 1); - putArray(elementsAppended, arrayData().elementsAppended, length); - return elementsAppended++; - } - - /** - * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this - * recursively appends a NULL to its children. - * We don't have this logic as the general appendNull implementation to optimize the more - * common non-struct case. - */ - public final int appendStruct(boolean isNull) { - if (isNull) { - appendNull(); - for (ColumnVector c: childColumns) { - if (c.type instanceof StructType) { - c.appendStruct(true); - } else { - c.appendNull(); - } - } - } else { - appendNotNull(); - } - return elementsAppended; - } + public abstract byte[] getBinary(int rowId); /** * Returns the data for the underlying array. */ - public final ColumnVector arrayData() { return childColumns[0]; } + public abstract ColumnVector arrayData(); /** * Returns the ordinal's child data column. */ - public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } - - /** - * Returns the elements appended. - */ - public final int getElementsAppended() { return elementsAppended; } + public abstract ColumnVector getChildColumn(int ordinal); /** * Returns true if this column is an array. 
*/ public final boolean isArray() { return resultArray != null; } - /** - * Marks this column as being constant. - */ - public final void setIsConstant() { isConstant = true; } - - /** - * Maximum number of rows that can be stored in this column. - */ - protected int capacity; - - /** - * Upper limit for the maximum capacity for this column. - */ - @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE; - /** * Data type for this column. */ protected DataType type; - /** - * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. - */ - protected int numNulls; - - /** - * True if there is at least one NULL byte set. This is an optimization for the writer, to skip - * having to clear NULL bits. - */ - protected boolean anyNullsSet; - - /** - * True if this column's values are fixed. This means the column values never change, even - * across resets. - */ - protected boolean isConstant; - - /** - * Default size of each array length value. This grows as necessary. - */ - protected static final int DEFAULT_ARRAY_LENGTH = 4; - - /** - * Current write cursor (row index) when appending data. - */ - protected int elementsAppended; - - /** - * If this is a nested type (array or struct), the column for the child data. - */ - protected ColumnVector[] childColumns; - /** * Reusable Array holder for getArray(). */ - protected Array resultArray; + protected ColumnVector.Array resultArray; /** * Reusable Struct holder for getStruct(). @@ -1012,32 +463,11 @@ public final int appendStruct(boolean isNull) { */ protected ColumnVector dictionaryIds; - /** - * Update the dictionary. - */ - public void setDictionary(Dictionary dictionary) { - this.dictionary = dictionary; - } - /** * Returns true if this column has a dictionary. */ public boolean hasDictionary() { return this.dictionary != null; } - /** - * Reserve a integer column for ids of dictionary. - */ - public ColumnVector reserveDictionaryIds(int capacity) { - if (dictionaryIds == null) { - dictionaryIds = allocate(capacity, DataTypes.IntegerType, - this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(capacity); - } - return dictionaryIds; - } - /** * Returns the underlying integer column for ids of dictionary. */ @@ -1049,43 +479,7 @@ public ColumnVector getDictionaryIds() { * Sets up the common state and also handles creating the child columns if this is a nested * type. 
   */
-  protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
-    this.capacity = capacity;
+  protected ColumnVector(DataType type) {
     this.type = type;
-
-    if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
-        || DecimalType.isByteArrayDecimalType(type)) {
-      DataType childType;
-      int childCapacity = capacity;
-      if (type instanceof ArrayType) {
-        childType = ((ArrayType)type).elementType();
-      } else {
-        childType = DataTypes.ByteType;
-        childCapacity *= DEFAULT_ARRAY_LENGTH;
-      }
-      this.childColumns = new ColumnVector[1];
-      this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode);
-      this.resultArray = new Array(this.childColumns[0]);
-      this.resultStruct = null;
-    } else if (type instanceof StructType) {
-      StructType st = (StructType)type;
-      this.childColumns = new ColumnVector[st.fields().length];
-      for (int i = 0; i < childColumns.length; ++i) {
-        this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode);
-      }
-      this.resultArray = null;
-      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
-    } else if (type instanceof CalendarIntervalType) {
-      // Two columns. Months as int. Microseconds as Long.
-      this.childColumns = new ColumnVector[2];
-      this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
-      this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
-      this.resultArray = null;
-      this.resultStruct = new ColumnarBatch.Row(this.childColumns);
-    } else {
-      this.childColumns = null;
-      this.resultArray = null;
-      this.resultStruct = null;
-    }
   }
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 900d7c431e723..adb859ed17757 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -40,7 +40,7 @@ public class ColumnVectorUtils {
   /**
    * Populates the entire `col` with `row[fieldIdx]`
    */
-  public static void populate(ColumnVector col, InternalRow row, int fieldIdx) {
+  public static void populate(WritableColumnVector col, InternalRow row, int fieldIdx) {
     int capacity = col.capacity;
     DataType t = col.dataType();

@@ -115,7 +115,7 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) {
     }
   }

-  private static void appendValue(ColumnVector dst, DataType t, Object o) {
+  private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
     if (o == null) {
       if (t instanceof CalendarIntervalType) {
         dst.appendStruct(true);
@@ -165,7 +165,7 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) {
     }
   }

-  private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) {
+  private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) {
     if (t instanceof ArrayType) {
       ArrayType at = (ArrayType)t;
       if (src.isNullAt(fieldIdx)) {
@@ -198,15 +198,23 @@ private static void appendValue(ColumnVector dst, DataType t, Row src, int field
    */
   public static ColumnarBatch toBatch(
       StructType schema, MemoryMode memMode, Iterator<Row> row) {
-    ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode);
+    int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
+    WritableColumnVector[] columnVectors;
+    if (memMode == MemoryMode.OFF_HEAP) {
+      columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema);
+    } else {
+      columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
+    }
+
     int n = 0;
     while (row.hasNext()) {
       Row r = row.next();
       for (int i = 0; i < schema.fields().length; i++) {
-        appendValue(batch.column(i), schema.fields()[i].dataType(), r, i);
+        appendValue(columnVectors[i], schema.fields()[i].dataType(), r, i);
       }
       n++;
     }
+    ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity);
     batch.setNumRows(n);
     return batch;
   }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 34dc3af9b85c8..e782756a3e781 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -19,7 +19,6 @@
 import java.math.BigDecimal;
 import java.util.*;

-import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -44,8 +43,7 @@
  * - Compaction: The batch and columns should be able to compact based on a selection vector.
  */
 public final class ColumnarBatch {
-  private static final int DEFAULT_BATCH_SIZE = 4 * 1024;
-  private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
+  public static final int DEFAULT_BATCH_SIZE = 4 * 1024;

   private final StructType schema;
   private final int capacity;
@@ -64,18 +62,6 @@ public final class ColumnarBatch {
   // Staging row returned from getRow.
   final Row row;

-  public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
-    return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
-  }
-
-  public static ColumnarBatch allocate(StructType type) {
-    return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE);
-  }
-
-  public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) {
-    return new ColumnarBatch(schema, maxRows, memMode);
-  }
-
   /**
    * Called to close all the columns in this batch. It is not valid to access the data after
    * calling this. This must be called at the end to clean up memory allocations.
@@ -95,12 +81,19 @@ public static final class Row extends InternalRow {
     private final ColumnarBatch parent;
     private final int fixedLenRowSize;
     private final ColumnVector[] columns;
+    private final WritableColumnVector[] writableColumns;

     // Ctor used if this is a top level row.
     private Row(ColumnarBatch parent) {
       this.parent = parent;
       this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
       this.columns = parent.columns;
+      this.writableColumns = new WritableColumnVector[this.columns.length];
+      for (int i = 0; i < this.columns.length; i++) {
+        if (this.columns[i] instanceof WritableColumnVector) {
+          this.writableColumns[i] = (WritableColumnVector) this.columns[i];
+        }
+      }
     }

     // Ctor used if this is a struct.
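With the `allocate` factories gone, every former caller follows the same replacement recipe that `VectorizedParquetRecordReader.initBatch` and `ColumnVectorUtils.toBatch` use above: allocate the writable vectors for the chosen memory mode, then hand them to the now-public `ColumnarBatch` constructor. A sketch of that recipe as a standalone helper (`newBatch` is a hypothetical name; the APIs are the ones introduced in this patch):

```java
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class BatchAllocationSketch {
  // Replaces the removed ColumnarBatch.allocate(schema, memMode).
  static ColumnarBatch newBatch(StructType schema, MemoryMode memMode) {
    int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
    WritableColumnVector[] columnVectors = (memMode == MemoryMode.OFF_HEAP)
        ? OffHeapColumnVector.allocateColumns(capacity, schema)
        : OnHeapColumnVector.allocateColumns(capacity, schema);
    // The batch no longer allocates anything; it only wraps vectors the caller owns.
    return new ColumnarBatch(schema, columnVectors, capacity);
  }

  public static void main(String[] args) {
    StructType schema = new StructType().add("id", DataTypes.LongType);
    ColumnarBatch batch = newBatch(schema, MemoryMode.ON_HEAP);
    batch.setNumRows(0);
    batch.close();
  }
}
```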
@@ -108,6 +101,12 @@ protected Row(ColumnVector[] columns) { this.parent = null; this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length); this.columns = columns; + this.writableColumns = new WritableColumnVector[this.columns.length]; + for (int i = 0; i < this.columns.length; i++) { + if (this.columns[i] instanceof WritableColumnVector) { + this.writableColumns[i] = (WritableColumnVector) this.columns[i]; + } + } } /** @@ -307,64 +306,69 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNull(rowId); + getWritableColumn(ordinal).putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putBoolean(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putByte(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putShort(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putInt(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putLong(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putFloat(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDouble(rowId, value); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - assert (!columns[ordinal].isConstant); - columns[ordinal].putNotNull(rowId); - columns[ordinal].putDecimal(rowId, value, precision); + WritableColumnVector column = getWritableColumn(ordinal); + column.putNotNull(rowId); + column.putDecimal(rowId, value, precision); + } + + private WritableColumnVector getWritableColumn(int ordinal) { + WritableColumnVector column = writableColumns[ordinal]; + assert (!column.isConstant); + return column; } } @@ -409,7 +413,9 @@ public void remove() { */ public void reset() { for (int i = 0; i < numCols(); ++i) { - columns[i].reset(); + if (columns[i] instanceof WritableColumnVector) { + ((WritableColumnVector) 
columns[i]).reset(); + } } if (this.numRowsFiltered > 0) { Arrays.fill(filteredRows, false); @@ -427,7 +433,7 @@ public void setNumRows(int numRows) { this.numRows = numRows; for (int ordinal : nullFilteredColumns) { - if (columns[ordinal].numNulls != 0) { + if (columns[ordinal].numNulls() != 0) { for (int rowId = 0; rowId < numRows; rowId++) { if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) { filteredRows[rowId] = true; @@ -505,18 +511,12 @@ public void filterNullsInColumn(int ordinal) { nullFilteredColumns.add(ordinal); } - private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) { + public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { this.schema = schema; - this.capacity = maxRows; - this.columns = new ColumnVector[schema.size()]; + this.columns = columns; + this.capacity = capacity; this.nullFilteredColumns = new HashSet<>(); - this.filteredRows = new boolean[maxRows]; - - for (int i = 0; i < schema.fields().length; ++i) { - StructField field = schema.fields()[i]; - columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode); - } - + this.filteredRows = new boolean[capacity]; this.row = new Row(this); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 2d1f3da8e7463..35682756ed6c3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -19,18 +19,39 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. */ -public final class OffHeapColumnVector extends ColumnVector { +public final class OffHeapColumnVector extends WritableColumnVector { private static final boolean bigEndianPlatform = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + /** + * Allocates columns to store elements of each field of the schema off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OffHeapColumnVector[] allocateColumns(int capacity, StructType schema) { + return allocateColumns(capacity, schema.fields()); + } + + /** + * Allocates columns to store elements of each field off heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { + OffHeapColumnVector[] vectors = new OffHeapColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + vectors[i] = new OffHeapColumnVector(capacity, fields[i].dataType()); + } + return vectors; + } + // The data stored in these two allocations need to maintain binary compatible. We can // directly pass this buffer to external components. 
private long nulls; @@ -40,8 +61,8 @@ public final class OffHeapColumnVector extends ColumnVector { private long lengthData; private long offsetData; - protected OffHeapColumnVector(int capacity, DataType type) { - super(capacity, type, MemoryMode.OFF_HEAP); + public OffHeapColumnVector(int capacity, DataType type) { + super(capacity, type); nulls = 0; data = 0; @@ -519,4 +540,9 @@ protected void reserveInternal(int newCapacity) { Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } + + @Override + protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OffHeapColumnVector(capacity, type); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 506434364be48..96a452978cb35 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -20,7 +20,6 @@ import java.nio.ByteOrder; import java.util.Arrays; -import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; @@ -28,11 +27,33 @@ * A column backed by an in memory JVM array. This stores the NULLs as a byte per value * and a java array for the values. */ -public final class OnHeapColumnVector extends ColumnVector { +public final class OnHeapColumnVector extends WritableColumnVector { private static final boolean bigEndianPlatform = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN); + /** + * Allocates columns to store elements of each field of the schema on heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OnHeapColumnVector[] allocateColumns(int capacity, StructType schema) { + return allocateColumns(capacity, schema.fields()); + } + + /** + * Allocates columns to store elements of each field on heap. + * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is + * in number of elements, not number of bytes. + */ + public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) { + OnHeapColumnVector[] vectors = new OnHeapColumnVector[fields.length]; + for (int i = 0; i < fields.length; i++) { + vectors[i] = new OnHeapColumnVector(capacity, fields[i].dataType()); + } + return vectors; + } + // The data stored in these arrays need to maintain binary compatible. We can // directly pass this buffer to external components. 
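The two allocateColumns factories replace the removed ColumnVector.allocate(...) entry point: callers now choose the memory mode by choosing a concrete vector class, and hand the resulting vectors to the new public ColumnarBatch constructor. A minimal sketch of the resulting call pattern, assuming an on-heap batch and a hypothetical two-field schema:

  StructType schema = new StructType()
      .add("id", DataTypes.LongType)
      .add("name", DataTypes.StringType);
  int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
  OnHeapColumnVector[] columns = OnHeapColumnVector.allocateColumns(capacity, schema);
  ColumnarBatch batch = new ColumnarBatch(schema, columns, capacity);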
@@ -51,8 +72,9 @@ public final class OnHeapColumnVector extends ColumnVector { private int[] arrayLengths; private int[] arrayOffsets; - protected OnHeapColumnVector(int capacity, DataType type) { - super(capacity, type, MemoryMode.ON_HEAP); + public OnHeapColumnVector(int capacity, DataType type) { + super(capacity, type); + reserveInternal(capacity); reset(); } @@ -529,4 +551,9 @@ protected void reserveInternal(int newCapacity) { capacity = newCapacity; } + + @Override + protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) { + return new OnHeapColumnVector(capacity, type); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java deleted file mode 100644 index e9f6e7c631fd4..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java +++ /dev/null @@ -1,251 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.types.*; - -/** - * An abstract class for read-only column vector. 
- */ -public abstract class ReadOnlyColumnVector extends ColumnVector { - - protected ReadOnlyColumnVector(int capacity, DataType type, MemoryMode memMode) { - super(capacity, DataTypes.NullType, memMode); - this.type = type; - isConstant = true; - } - - // - // APIs dealing with nulls - // - - @Override - public final void putNotNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNull(int rowId) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putNotNulls(int rowId, int count) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Booleans - // - - @Override - public final void putBoolean(int rowId, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBooleans(int rowId, int count, boolean value) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Bytes - // - - @Override - public final void putByte(int rowId, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBytes(int rowId, int count, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Shorts - // - - @Override - public final void putShort(int rowId, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putShorts(int rowId, int count, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putShorts(int rowId, int count, short[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Ints - // - - @Override - public final void putInt(int rowId, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putInts(int rowId, int count, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putInts(int rowId, int count, int[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Longs - // - - @Override - public final void putLong(int rowId, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongs(int rowId, int count, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongs(int rowId, int count, long[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with floats - // - - @Override - public final void putFloat(int rowId, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, float[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putFloats(int rowId, int count, 
byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with doubles - // - - @Override - public final void putDouble(int rowId, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, double[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - @Override - public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Arrays - // - - @Override - public final void putArray(int rowId, int offset, int length) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Byte Arrays - // - - @Override - public final int putByteArray(int rowId, byte[] value, int offset, int count) { - throw new UnsupportedOperationException(); - } - - // - // APIs dealing with Decimals - // - - @Override - public final void putDecimal(int rowId, Decimal value, int precision) { - throw new UnsupportedOperationException(); - } - - // - // Other APIs - // - - @Override - public final void setDictionary(Dictionary dictionary) { - throw new UnsupportedOperationException(); - } - - @Override - public final ColumnVector reserveDictionaryIds(int capacity) { - throw new UnsupportedOperationException(); - } - - @Override - protected final void reserveInternal(int newCapacity) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java new file mode 100644 index 0000000000000..b4f753c0bc2a3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import java.math.BigDecimal; +import java.math.BigInteger; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * This class adds write APIs to ColumnVector. + * It supports all the types and contains put APIs as well as their batched versions. + * The batched versions are preferable whenever possible. + * + * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the + * responsibility of the caller to call reserve() to ensure there is enough room before adding + * elements. 
This means that the put() APIs do not check capacity, since in common cases (i.e. flat schemas), + the lengths are known up front. + * + * A ColumnVector should be considered immutable once originally created. In other words, it is not + * valid to call put APIs after reads until reset() is called. + */ +public abstract class WritableColumnVector extends ColumnVector { + + /** + * Resets this column for writing. The currently stored values are no longer accessible. + */ + public void reset() { + if (isConstant) return; + + if (childColumns != null) { + for (ColumnVector c: childColumns) { + ((WritableColumnVector) c).reset(); + } + } + numNulls = 0; + elementsAppended = 0; + if (anyNullsSet) { + putNotNulls(0, capacity); + anyNullsSet = false; + } + } + + public void reserve(int requiredCapacity) { + if (requiredCapacity > capacity) { + int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); + if (requiredCapacity <= newCapacity) { + try { + reserveInternal(newCapacity); + } catch (OutOfMemoryError outOfMemoryError) { + throwUnsupportedException(requiredCapacity, outOfMemoryError); + } + } else { + throwUnsupportedException(requiredCapacity, null); + } + } + } + + private void throwUnsupportedException(int requiredCapacity, Throwable cause) { + String message = "Cannot reserve additional contiguous bytes in the vectorized reader " + + "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " + + "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() + + " to false."; + throw new RuntimeException(message, cause); + } + + @Override + public int numNulls() { return numNulls; } + + @Override + public boolean anyNullsSet() { return anyNullsSet; } + + /** + * Ensures that there is enough storage to store capacity elements. That is, the put() APIs + * must work for all rowIds < capacity. + */ + protected abstract void reserveInternal(int capacity); + + /** + * Sets the value at rowId to null/not null. + */ + public abstract void putNotNull(int rowId); + public abstract void putNull(int rowId); + + /** + * Sets the values from [rowId, rowId + count) to null/not null. + */ + public abstract void putNulls(int rowId, int count); + public abstract void putNotNulls(int rowId, int count); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putBoolean(int rowId, boolean value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBooleans(int rowId, int count, boolean value); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putByte(int rowId, byte value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putBytes(int rowId, int count, byte value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putShort(int rowId, short value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putShorts(int rowId, int count, short value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putInt(int rowId, int value); + + /** + * Sets values from [rowId, rowId + count) to value.
+ */ + public abstract void putInts(int rowId, int count, int value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putInts(int rowId, int count, int[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 4-byte little endian ints. + */ + public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putLong(int rowId, long value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putLongs(int rowId, int count, long value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be 8-byte little endian longs. + */ + public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putFloat(int rowId, float value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putFloats(int rowId, int count, float value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted floats. + */ + public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets the value at rowId to `value`. + */ + public abstract void putDouble(int rowId, double value); + + /** + * Sets values from [rowId, rowId + count) to value. + */ + public abstract void putDoubles(int rowId, int count, double value); + + /** + * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + */ + public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); + + /** + * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * The data in src must be ieee formatted doubles. + */ + public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); + + /** + * Puts a byte array that already exists in this column. + */ + public abstract void putArray(int rowId, int offset, int length); + + /** + * Sets the value at rowId to `value`. + */ + public abstract int putByteArray(int rowId, byte[] value, int offset, int count); + public final int putByteArray(int rowId, byte[] value) { + return putByteArray(rowId, value, 0, value.length); + } + + /** + * Returns the value for rowId. + */ + private ColumnVector.Array getByteArray(int rowId) { + ColumnVector.Array array = getArray(rowId); + array.data.loadBytes(array); + return array; + } + + /** + * Returns the decimal for rowId. + */ + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.createUnsafe(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.createUnsafe(getLong(rowId), precision, scale); + } else { + // TODO: best perf? 
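+ // The bytes come back exactly as putDecimal() below stored them: a big-endian two's-complement array from BigInteger.toByteArray(), which is the format the BigInteger(byte[]) constructor expects.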
+ byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + public void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + putInt(rowId, (int) value.toUnscaledLong()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + putLong(rowId, value.toUnscaledLong()); + } else { + BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); + putByteArray(rowId, bigInteger.toByteArray()); + } + } + + /** + * Returns the UTF8String for rowId. + */ + @Override + public UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); + return UTF8String.fromBytes(bytes); + } + } + + /** + * Returns the byte array for rowId. + */ + @Override + public byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId)); + } + } + + /** + * Append APIs. These APIs all behave similarly and will append data to the current vector. It + * is not valid to mix the put and append APIs. The append APIs are slower and should only be + * used if the sizes are not known up front. + * In all these cases, the return value is the rowId for the first appended element. + */ + public final int appendNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() + reserve(elementsAppended + 1); + putNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNotNull() { + reserve(elementsAppended + 1); + putNotNull(elementsAppended); + return elementsAppended++; + } + + public final int appendNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendNotNulls(int count) { + assert (!(dataType() instanceof StructType)); + reserve(elementsAppended + count); + int result = elementsAppended; + putNotNulls(elementsAppended, count); + elementsAppended += count; + return result; + } + + public final int appendBoolean(boolean v) { + reserve(elementsAppended + 1); + putBoolean(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBooleans(int count, boolean v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendByte(byte v) { + reserve(elementsAppended + 1); + putByte(elementsAppended, v); + return elementsAppended++; + } + + public final int appendBytes(int count, byte v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBytes(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendBytes(int length, byte[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putBytes(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } 
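+ + // Illustrative usage of the append path (assuming a fresh IntegerType vector `col`): col.appendInt(7) reserves room, writes row 0, and returns 0; a following col.appendInts(3, 42) fills rows 1..3 and returns 1, leaving col.getElementsAppended() == 4.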
+ + public final int appendShort(short v) { + reserve(elementsAppended + 1); + putShort(elementsAppended, v); + return elementsAppended++; + } + + public final int appendShorts(int count, short v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putShorts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendShorts(int length, short[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putShorts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendInt(int v) { + reserve(elementsAppended + 1); + putInt(elementsAppended, v); + return elementsAppended++; + } + + public final int appendInts(int count, int v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putInts(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendInts(int length, int[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putInts(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendLong(long v) { + reserve(elementsAppended + 1); + putLong(elementsAppended, v); + return elementsAppended++; + } + + public final int appendLongs(int count, long v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putLongs(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendLongs(int length, long[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putLongs(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendFloat(float v) { + reserve(elementsAppended + 1); + putFloat(elementsAppended, v); + return elementsAppended++; + } + + public final int appendFloats(int count, float v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putFloats(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendFloats(int length, float[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putFloats(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendDouble(double v) { + reserve(elementsAppended + 1); + putDouble(elementsAppended, v); + return elementsAppended++; + } + + public final int appendDoubles(int count, double v) { + reserve(elementsAppended + count); + int result = elementsAppended; + putDoubles(elementsAppended, count, v); + elementsAppended += count; + return result; + } + + public final int appendDoubles(int length, double[] src, int offset) { + reserve(elementsAppended + length); + int result = elementsAppended; + putDoubles(elementsAppended, length, src, offset); + elementsAppended += length; + return result; + } + + public final int appendByteArray(byte[] value, int offset, int length) { + int copiedOffset = arrayData().appendBytes(length, value, offset); + reserve(elementsAppended + 1); + putArray(elementsAppended, copiedOffset, length); + return elementsAppended++; + } + + public final int appendArray(int length) { + reserve(elementsAppended + 1); + putArray(elementsAppended, arrayData().elementsAppended, length); + return elementsAppended++; + } + + /** + * Appends a NULL struct. 
This *has* to be used for structs instead of appendNull() as this + * recursively appends a NULL to its children. + * We don't fold this into the general appendNull() implementation, so that the more + * common non-struct case stays fast. + */ + public final int appendStruct(boolean isNull) { + if (isNull) { + appendNull(); + for (ColumnVector c: childColumns) { + if (c.type instanceof StructType) { + ((WritableColumnVector) c).appendStruct(true); + } else { + ((WritableColumnVector) c).appendNull(); + } + } + } else { + appendNotNull(); + } + return elementsAppended; + } + + /** + * Returns the data for the underlying array. + */ + @Override + public WritableColumnVector arrayData() { return childColumns[0]; } + + /** + * Returns the ordinal's child data column. + */ + @Override + public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + /** + * Returns the elements appended. + */ + public final int getElementsAppended() { return elementsAppended; } + + /** + * Marks this column as being constant. + */ + public final void setIsConstant() { isConstant = true; } + + /** + * Maximum number of rows that can be stored in this column. + */ + protected int capacity; + + /** + * Upper limit for the maximum capacity of this column. + */ + @VisibleForTesting + protected int MAX_CAPACITY = Integer.MAX_VALUE; + + /** + * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. + */ + protected int numNulls; + + /** + * True if there is at least one NULL byte set. This is an optimization for the writer, to skip + * having to clear NULL bits. + */ + protected boolean anyNullsSet; + + /** + * True if this column's values are fixed. This means the column values never change, even + * across resets. + */ + protected boolean isConstant; + + /** + * Default size of each array length value. This grows as necessary. + */ + protected static final int DEFAULT_ARRAY_LENGTH = 4; + + /** + * Current write cursor (row index) when appending data. + */ + protected int elementsAppended; + + /** + * If this is a nested type (array or struct), the column for the child data. + */ + protected WritableColumnVector[] childColumns; + + /** + * Updates the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserves an integer column for dictionary ids. + */ + public WritableColumnVector reserveDictionaryIds(int capacity) { + WritableColumnVector dictionaryIds = (WritableColumnVector) this.dictionaryIds; + if (dictionaryIds == null) { + dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType); + this.dictionaryIds = dictionaryIds; + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + + /** + * Returns the underlying integer column for dictionary ids. + */ + @Override + public WritableColumnVector getDictionaryIds() { + return (WritableColumnVector) dictionaryIds; + } + + /** + * Reserves a new column of the given capacity and type, using this vector's memory mode. + */ + protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type); + + /** + * Sets up the common state and also handles creating the child columns if this is a nested + * type.
+ */ + protected WritableColumnVector(int capacity, DataType type) { + super(type); + this.capacity = capacity; + + if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType + || DecimalType.isByteArrayDecimalType(type)) { + DataType childType; + int childCapacity = capacity; + if (type instanceof ArrayType) { + childType = ((ArrayType)type).elementType(); + } else { + childType = DataTypes.ByteType; + childCapacity *= DEFAULT_ARRAY_LENGTH; + } + this.childColumns = new WritableColumnVector[1]; + this.childColumns[0] = reserveNewColumn(childCapacity, childType); + this.resultArray = new ColumnVector.Array(this.childColumns[0]); + this.resultStruct = null; + } else if (type instanceof StructType) { + StructType st = (StructType)type; + this.childColumns = new WritableColumnVector[st.fields().length]; + for (int i = 0; i < childColumns.length; ++i) { + this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType()); + } + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else if (type instanceof CalendarIntervalType) { + // Two columns. Months as int. Microseconds as Long. + this.childColumns = new WritableColumnVector[2]; + this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType); + this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType); + this.resultArray = null; + this.resultStruct = new ColumnarBatch.Row(this.childColumns); + } else { + this.childColumns = null; + this.resultArray = null; + this.resultStruct = null; + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0c40417db0837..13f79275cac41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -76,6 +76,8 @@ class VectorizedHashMapGenerator( }.mkString("\n").concat(";") s""" + | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors; + | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors; | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; | private int[] buckets; @@ -89,14 +91,19 @@ class VectorizedHashMapGenerator( | $generatedAggBufferSchema | | public $generatedClassName() { - | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, - | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); - | // TODO: Possibly generate this projection in HashAggregate directly - | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( - | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); - | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { - | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); + | batchVectors = org.apache.spark.sql.execution.vectorized + | .OnHeapColumnVector.allocateColumns(capacity, schema); + | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( + | schema, batchVectors, capacity); + | + | bufferVectors = new org.apache.spark.sql.execution.vectorized + | .OnHeapColumnVector[aggregateBufferSchema.fields().length]; + | for (int i = 0; i < 
aggregateBufferSchema.fields().length; i++) { + | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}]; | } + | // TODO: Possibly generate this projection in HashAggregate directly + | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch( + | aggregateBufferSchema, bufferVectors, capacity); | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); @@ -112,8 +119,8 @@ class VectorizedHashMapGenerator( * * {{{ * private boolean equals(int idx, long agg_key, long agg_key1) { - * return batch.column(0).getLong(buckets[idx]) == agg_key && - * batch.column(1).getLong(buckets[idx]) == agg_key1; + * return batchVectors[0].getLong(buckets[idx]) == agg_key && + * batchVectors[1].getLong(buckets[idx]) == agg_key1; * } * }}} */ @@ -121,8 +128,8 @@ class VectorizedHashMapGenerator( def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - s"""(${ctx.genEqual(key.dataType, ctx.getValue("batch", "buckets[idx]", - key.dataType, ordinal), key.name)})""" + s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]", + key.dataType), key.name)})""" }.mkString(" && ") } @@ -150,9 +157,9 @@ class VectorizedHashMapGenerator( * while (step < maxSteps) { * // Return bucket index if it's either an empty slot or already contains the key * if (buckets[idx] == -1) { - * batch.column(0).putLong(numRows, agg_key); - * batch.column(1).putLong(numRows, agg_key1); - * batch.column(2).putLong(numRows, 0); + * batchVectors[0].putLong(numRows, agg_key); + * batchVectors[1].putLong(numRows, agg_key1); + * batchVectors[2].putLong(numRows, 0); * buckets[idx] = numRows++; * return batch.getRow(buckets[idx]); * } else if (equals(idx, agg_key, agg_key1)) { @@ -170,13 +177,13 @@ class VectorizedHashMapGenerator( def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = { groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.setValue("batch", "numRows", key.dataType, ordinal, key.name) + ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name) } } def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = { bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) => - ctx.updateColumn("batch", "numRows", key.dataType, groupingKeys.length + ordinal, + ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType, buffVars(ordinal), nullable = true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 67b3d98c1daed..1331f157363b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -24,7 +24,10 @@ import scala.util.Random import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.ColumnVector -import org.apache.spark.sql.types.{BinaryType, IntegerType} +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.execution.vectorized.WritableColumnVector +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType} import org.apache.spark.unsafe.Platform import 
org.apache.spark.util.Benchmark import org.apache.spark.util.collection.BitSet @@ -34,6 +37,14 @@ import org.apache.spark.util.collection.BitSet */ object ColumnarBatchBenchmark { + def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + // This benchmark reads and writes an array of ints. // TODO: there is a big (2x) penalty for a random access API for off heap. // Note: carefully if modifying this code. It's hard to reason about the JIT. @@ -140,7 +151,7 @@ object ColumnarBatchBenchmark { // Access through the column API with on heap memory val columnOnHeap = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -159,7 +170,7 @@ object ColumnarBatchBenchmark { // Access through the column API with off heap memory def columnOffHeap = { i: Int => { - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -178,7 +189,7 @@ object ColumnarBatchBenchmark { // Access by directly getting the buffer backing the column. val columnOffheapDirect = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP) + val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP) var sum = 0L for (n <- 0L until iters) { var addr = col.valuesNativeAddress() @@ -244,7 +255,7 @@ object ColumnarBatchBenchmark { // Adding values by appending, instead of putting. val onHeapAppend = { i: Int => - val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP) + val col = allocate(count, IntegerType, MemoryMode.ON_HEAP) var sum = 0L for (n <- 0L until iters) { var i = 0 @@ -362,7 +373,7 @@ object ColumnarBatchBenchmark { .map(_.getBytes(StandardCharsets.UTF_8)).toArray def column(memoryMode: MemoryMode) = { i: Int => - val column = ColumnVector.allocate(count, BinaryType, memoryMode) + val column = allocate(count, BinaryType, memoryMode) var sum = 0L for (n <- 0L until iters) { var i = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index c8461dcb9dfdb..08ccbd628cf8f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -34,11 +34,20 @@ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { + + def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + if (memMode == MemoryMode.OFF_HEAP) { + new OffHeapColumnVector(capacity, dt) + } else { + new OnHeapColumnVector(capacity, dt) + } + } + test("Null Apis") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[Boolean] - val column = ColumnVector.allocate(1024, IntegerType, memMode) + val column = allocate(1024, IntegerType, memMode) var idx = 0 assert(column.anyNullsSet() == false) assert(column.numNulls() == 0) @@ -109,7 +118,7 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: 
MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[Byte] - val column = ColumnVector.allocate(1024, ByteType, memMode) + val column = allocate(1024, ByteType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray column.appendBytes(2, values, 0) @@ -167,7 +176,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = ColumnVector.allocate(1024, ShortType, memMode) + val column = allocate(1024, ShortType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray column.appendShorts(2, values, 0) @@ -247,7 +256,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = ColumnVector.allocate(1024, IntegerType, memMode) + val column = allocate(1024, IntegerType, memMode) var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray column.appendInts(2, values, 0) @@ -332,7 +341,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = ColumnVector.allocate(1024, LongType, memMode) + val column = allocate(1024, LongType, memMode) var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray column.appendLongs(2, values, 0) @@ -419,7 +428,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] - val column = ColumnVector.allocate(1024, FloatType, memMode) + val column = allocate(1024, FloatType, memMode) var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray column.appendFloats(2, values, 0) @@ -510,7 +519,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = ColumnVector.allocate(1024, DoubleType, memMode) + val column = allocate(1024, DoubleType, memMode) var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray column.appendDoubles(2, values, 0) @@ -599,7 +608,7 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val reference = mutable.ArrayBuffer.empty[String] - val column = ColumnVector.allocate(6, BinaryType, memMode) + val column = allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) val str = "string" @@ -656,7 +665,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Int Array") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode) + val column = allocate(10, new ArrayType(IntegerType, true), memMode) // Fill the underlying data with all the arrays back to back. 
val data = column.arrayData(); @@ -714,43 +723,43 @@ class ColumnarBatchSuite extends SparkFunSuite { (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { val len = 4 - val columnBool = ColumnVector.allocate(len, new ArrayType(BooleanType, false), memMode) + val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) val boolArray = Array(false, true, false, true) boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } columnBool.putArray(0, 0, len) assert(columnBool.getArray(0).toBooleanArray === boolArray) - val columnByte = ColumnVector.allocate(len, new ArrayType(ByteType, false), memMode) + val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) val byteArray = Array[Byte](0, 1, 2, 3) byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } columnByte.putArray(0, 0, len) assert(columnByte.getArray(0).toByteArray === byteArray) - val columnShort = ColumnVector.allocate(len, new ArrayType(ShortType, false), memMode) + val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) val shortArray = Array[Short](0, 1, 2, 3) shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } columnShort.putArray(0, 0, len) assert(columnShort.getArray(0).toShortArray === shortArray) - val columnInt = ColumnVector.allocate(len, new ArrayType(IntegerType, false), memMode) + val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) val intArray = Array(0, 1, 2, 3) intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } columnInt.putArray(0, 0, len) assert(columnInt.getArray(0).toIntArray === intArray) - val columnLong = ColumnVector.allocate(len, new ArrayType(LongType, false), memMode) + val columnLong = allocate(len, new ArrayType(LongType, false), memMode) val longArray = Array[Long](0, 1, 2, 3) longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } columnLong.putArray(0, 0, len) assert(columnLong.getArray(0).toLongArray === longArray) - val columnFloat = ColumnVector.allocate(len, new ArrayType(FloatType, false), memMode) + val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } columnFloat.putArray(0, 0, len) assert(columnFloat.getArray(0).toFloatArray === floatArray) - val columnDouble = ColumnVector.allocate(len, new ArrayType(DoubleType, false), memMode) + val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) val doubleArray = Array(0.0, 1.1, 2.2, 3.3) doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } columnDouble.putArray(0, 0, len) @@ -761,7 +770,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Struct Column") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = ColumnVector.allocate(1024, schema, memMode) + val column = allocate(1024, schema, memMode) val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) @@ -790,7 +799,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Nest Array in Array.") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), + val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), memMode) val 
childColumn = column.arrayData() val data = column.arrayData().arrayData() @@ -823,7 +832,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("Nest Struct in Array.") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val schema = new StructType().add("int", IntegerType).add("long", LongType) - val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode) + val column = allocate(10, new ArrayType(schema, true), memMode) val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -853,7 +862,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val schema = new StructType() .add("int", IntegerType) .add("array", new ArrayType(IntegerType, true)) - val column = ColumnVector.allocate(10, schema, memMode) + val column = allocate(10, schema, memMode) val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -885,7 +894,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val schema = new StructType() .add("int", IntegerType) .add("struct", subSchema) - val column = ColumnVector.allocate(10, schema, memMode) + val column = allocate(10, schema, memMode) val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -918,7 +927,11 @@ class ColumnarBatchSuite extends SparkFunSuite { .add("intCol2", IntegerType) .add("string", BinaryType) - val batch = ColumnarBatch.allocate(schema, memMode) + val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE + val columns = schema.fields.map { field => + allocate(capacity, field.dataType, memMode) + } + val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE) assert(batch.numCols() == 4) assert(batch.numRows() == 0) assert(batch.numValidRows() == 0) @@ -926,10 +939,10 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add a row [1, 1.1, NULL] - batch.column(0).putInt(0, 1) - batch.column(1).putDouble(0, 1.1) - batch.column(2).putNull(0) - batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) + columns(0).putInt(0, 1) + columns(1).putDouble(0, 1.1) + columns(2).putNull(0) + columns(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(1) // Verify the results of the row. @@ -939,12 +952,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == true) assert(batch.rowIterator().hasNext == true) - assert(batch.column(0).getInt(0) == 1) - assert(batch.column(0).isNullAt(0) == false) - assert(batch.column(1).getDouble(0) == 1.1) - assert(batch.column(1).isNullAt(0) == false) - assert(batch.column(2).isNullAt(0) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(0).getInt(0) == 1) + assert(columns(0).isNullAt(0) == false) + assert(columns(1).getDouble(0) == 1.1) + assert(columns(1).isNullAt(0) == false) + assert(columns(2).isNullAt(0) == true) + assert(columns(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. 
val it = batch.rowIterator() @@ -955,7 +968,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(batch.column(3).getUTF8String(0).toString == "Hello") + assert(columns(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false) @@ -972,20 +985,20 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.rowIterator().hasNext == false) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] - batch.column(0).putNull(0) - batch.column(1).putDouble(0, 2.2) - batch.column(2).putInt(0, 2) - batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(1, 3) - batch.column(1).putNull(1) - batch.column(2).putInt(1, 3) - batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) - - batch.column(0).putInt(2, 4) - batch.column(1).putDouble(2, 4.4) - batch.column(2).putInt(2, 4) - batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) + columns(0).putNull(0) + columns(1).putDouble(0, 2.2) + columns(2).putInt(0, 2) + columns(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(1, 3) + columns(1).putNull(1) + columns(2).putInt(1, 3) + columns(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8)) + + columns(0).putInt(2, 4) + columns(1).putDouble(2, 4.4) + columns(2).putInt(2, 4) + columns(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8)) batch.setNumRows(3) def rowEquals(x: InternalRow, y: Row): Unit = { @@ -1232,7 +1245,7 @@ class ColumnarBatchSuite extends SparkFunSuite { test("exceeding maximum capacity should throw an error") { (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = ColumnVector.allocate(1, ByteType, memMode) + val column = allocate(1, ByteType, memMode) column.MAX_CAPACITY = 15 column.appendBytes(5, 0.toByte) // Successfully allocate twice the requested capacity