diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java
new file mode 100644
index 0000000000000..134cb05c1265c
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * This class adds constant-value support to ColumnVector.
+ * It supports all data types and provides `set` APIs,
+ * which assign one value to every row of the vector.
+ *
+ * Capacity: regardless of numRows, the vector stores only one copy of the data.
+ */
+public class ConstantColumnVector extends ColumnVector {
+
+  // The data stored in this ConstantColumnVector; only one copy of the data is kept.
+  private byte nullData;
+  private byte byteData;
+  private short shortData;
+  private int intData;
+  private long longData;
+  private float floatData;
+  private double doubleData;
+  private UTF8String stringData;
+  private byte[] byteArrayData;
+  private ConstantColumnVector[] childData;
+  private ColumnarArray arrayData;
+  private ColumnarMap mapData;
+
+  private final int numRows;
+
+  /**
+   * @param numRows the number of rows of this ConstantColumnVector
+   * @param type the data type of this ConstantColumnVector
+   */
+  public ConstantColumnVector(int numRows, DataType type) {
+    super(type);
+    this.numRows = numRows;
+
+    if (type instanceof StructType) {
+      this.childData = new ConstantColumnVector[((StructType) type).fields().length];
+    } else if (type instanceof CalendarIntervalType) {
+      // Three child columns: months as int, days as int, microseconds as long.
+      this.childData = new ConstantColumnVector[3];
+    } else {
+      this.childData = null;
+    }
+  }
+
+  @Override
+  public void close() {
+    byteArrayData = null;
+    // childData is null unless the type is a struct or calendar interval,
+    // and individual children may be null if they were never set
+    if (childData != null) {
+      for (int i = 0; i < childData.length; i++) {
+        if (childData[i] != null) {
+          childData[i].close();
+          childData[i] = null;
+        }
+      }
+      childData = null;
+    }
+    arrayData = null;
+    mapData = null;
+  }
+
+  @Override
+  public boolean hasNull() {
+    return nullData == 1;
+  }
+
+  @Override
+  public int numNulls() {
+    return hasNull() ? numRows : 0;
+  }
+
+  @Override
+  public boolean isNullAt(int rowId) {
+    return nullData == 1;
+  }
+
+  /**
+   * Sets all rows as `null`
+   */
+  public void setNull() {
+    nullData = (byte) 1;
+  }
+
+  /**
+   * Sets all rows as not `null`
+   */
+  public void setNotNull() {
+    nullData = (byte) 0;
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    return byteData == 1;
+  }
+
+  /**
+   * Sets the boolean `value` for all rows
+   */
+  public void setBoolean(boolean value) {
+    byteData = (byte) (value ? 1 : 0);
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    return byteData;
+  }
+
+  /**
+   * Sets the byte `value` for all rows
+   */
+  public void setByte(byte value) {
+    byteData = value;
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    return shortData;
+  }
+
+  /**
+   * Sets the short `value` for all rows
+   */
+  public void setShort(short value) {
+    shortData = value;
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    return intData;
+  }
+
+  /**
+   * Sets the int `value` for all rows
+   */
+  public void setInt(int value) {
+    intData = value;
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    return longData;
+  }
+
+  /**
+   * Sets the long `value` for all rows
+   */
+  public void setLong(long value) {
+    longData = value;
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    return floatData;
+  }
+
+  /**
+   * Sets the float `value` for all rows
+   */
+  public void setFloat(float value) {
+    floatData = value;
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    return doubleData;
+  }
+
+  /**
+   * Sets the double `value` for all rows
+   */
+  public void setDouble(double value) {
+    doubleData = value;
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    return arrayData;
+  }
+
+  /**
+   * Sets the `ColumnarArray` `value` for all rows
+   */
+  public void setArray(ColumnarArray value) {
+    arrayData = value;
+  }
+
+  @Override
+  public ColumnarMap getMap(int ordinal) {
+    return mapData;
+  }
+
+  /**
+   * Sets the `ColumnarMap` `value` for all rows
+   */
+  public void setMap(ColumnarMap value) {
+    mapData = value;
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    // copied and adapted from WritableColumnVector
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      return Decimal.createUnsafe(getInt(rowId), precision, scale);
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      return Decimal.createUnsafe(getLong(rowId), precision, scale);
+    } else {
+      byte[] bytes = getBinary(rowId);
+      BigInteger bigInteger = new BigInteger(bytes);
+      BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+      return Decimal.apply(javaDecimal, precision, scale);
+    }
+  }
+
+  /**
+   * Sets the `Decimal` `value` with the given precision for all rows
+   */
+  public void setDecimal(Decimal value, int precision) {
+    // copied and adapted from WritableColumnVector
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      setInt((int) value.toUnscaledLong());
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      setLong(value.toUnscaledLong());
+    } else {
+      BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
+      setByteArray(bigInteger.toByteArray());
+    }
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    return stringData;
+  }
+
+  /**
+   * Sets the `UTF8String` `value` for all rows
+   */
+  public void setUtf8String(UTF8String value) {
+    stringData = value;
+  }
+
+  /**
+   * Sets the byte array `value` for all rows
+   */
+  private void setByteArray(byte[] value) {
+    byteArrayData = value;
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    return byteArrayData;
+  }
+
+  /**
+   * Sets the binary `value` for all rows
+   */
+  public void setBinary(byte[] value) {
+    setByteArray(value);
+  }
+
+  @Override
+  public ColumnVector getChild(int ordinal) {
+    return childData[ordinal];
+  }
+
+  /**
+   * Sets the child `ConstantColumnVector` `value` at the given ordinal for all rows
+   */
+  public void setChild(int ordinal, ConstantColumnVector value) {
+    childData[ordinal] = value;
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 4bd6c239a3367..443553f6ade03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
 import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
@@ -221,8 +221,8 @@ case class FileSourceScanExec(
       requiredSchema = requiredSchema,
       partitionSchema = relation.partitionSchema,
       relation.sparkSession.sessionState.conf).map { vectorTypes =>
-        // for column-based file format, append metadata struct column's vector type classes if any
-        vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[OnHeapColumnVector].getName)
+        // for column-based file formats, append the metadata columns' vector type classes, if any
+        vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[ConstantColumnVector].getName)
       }

   private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 5baa597582553..379099ff1db67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.FileFormat._
-import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.types.{LongType, StringType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.unsafe.types.UTF8String
@@ -134,7 +134,7 @@ class FileScanRDD(
    * For each partitioned file, metadata columns for each record in the file are exactly same.
    * Only update metadata row when `currentFile` is changed.
*/ - private def updateMetadataRow(): Unit = { + private def updateMetadataRow(): Unit = if (metadataColumns.nonEmpty && currentFile != null) { val path = new Path(currentFile.filePath) metadataColumns.zipWithIndex.foreach { case (attr, i) => @@ -149,44 +149,30 @@ class FileScanRDD( } } } - } /** - * Create a writable column vector containing all required metadata columns + * Create an array of constant column vectors containing all required metadata columns */ - private def createMetadataColumnVector(c: ColumnarBatch): Array[WritableColumnVector] = { + private def createMetadataColumnVector(c: ColumnarBatch): Array[ConstantColumnVector] = { val path = new Path(currentFile.filePath) - val filePathBytes = path.toString.getBytes - val fileNameBytes = path.getName.getBytes - var rowId = 0 metadataColumns.map(_.name).map { case FILE_PATH => - val columnVector = new OnHeapColumnVector(c.numRows(), StringType) - rowId = 0 - // use a tight-loop for better performance - while (rowId < c.numRows()) { - columnVector.putByteArray(rowId, filePathBytes) - rowId += 1 - } + val columnVector = new ConstantColumnVector(c.numRows(), StringType) + columnVector.setUtf8String(UTF8String.fromString(path.toString)) columnVector case FILE_NAME => - val columnVector = new OnHeapColumnVector(c.numRows(), StringType) - rowId = 0 - // use a tight-loop for better performance - while (rowId < c.numRows()) { - columnVector.putByteArray(rowId, fileNameBytes) - rowId += 1 - } + val columnVector = new ConstantColumnVector(c.numRows(), StringType) + columnVector.setUtf8String(UTF8String.fromString(path.getName)) columnVector case FILE_SIZE => - val columnVector = new OnHeapColumnVector(c.numRows(), LongType) - columnVector.putLongs(0, c.numRows(), currentFile.fileSize) + val columnVector = new ConstantColumnVector(c.numRows(), LongType) + columnVector.setLong(currentFile.fileSize) columnVector case FILE_MODIFICATION_TIME => - val columnVector = new OnHeapColumnVector(c.numRows(), LongType) + val columnVector = new ConstantColumnVector(c.numRows(), LongType) // the modificationTime from the file is in millisecond, // while internally, the TimestampType is stored in microsecond - columnVector.putLongs(0, c.numRows(), currentFile.modificationTime * 1000L) + columnVector.setLong(currentFile.modificationTime * 1000L) columnVector }.toArray } @@ -198,10 +184,9 @@ class FileScanRDD( private def addMetadataColumnsIfNeeded(nextElement: Object): Object = { if (metadataColumns.nonEmpty) { nextElement match { - case c: ColumnarBatch => - new ColumnarBatch( - Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c), - c.numRows()) + case c: ColumnarBatch => new ColumnarBatch( + Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c), + c.numRows()) case u: UnsafeRow => projection.apply(new JoinedRow(u, metadataRow)) case i: InternalRow => new JoinedRow(i, metadataRow) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala new file mode 100644 index 0000000000000..c8438f342d256 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ConstantColumnVectorSuite.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap}
+import org.apache.spark.unsafe.types.UTF8String
+
+class ConstantColumnVectorSuite extends SparkFunSuite {
+
+  private def testVector(name: String, size: Int, dt: DataType)
+      (f: ConstantColumnVector => Unit): Unit = {
+    test(name) {
+      f(new ConstantColumnVector(size, dt))
+    }
+  }
+
+  testVector("null", 10, IntegerType) { vector =>
+    vector.setNull()
+    assert(vector.hasNull)
+    assert(vector.numNulls() == 10)
+    (0 until 10).foreach { i =>
+      assert(vector.isNullAt(i))
+    }
+
+    vector.setNotNull()
+    assert(!vector.hasNull)
+    assert(vector.numNulls() == 0)
+    (0 until 10).foreach { i =>
+      assert(!vector.isNullAt(i))
+    }
+  }
+
+  testVector("boolean", 10, BooleanType) { vector =>
+    vector.setBoolean(true)
+    (0 until 10).foreach { i =>
+      assert(vector.getBoolean(i))
+    }
+  }
+
+  testVector("byte", 10, ByteType) { vector =>
+    vector.setByte(3.toByte)
+    (0 until 10).foreach { i =>
+      assert(vector.getByte(i) == 3.toByte)
+    }
+  }
+
+  testVector("short", 10, ShortType) { vector =>
+    vector.setShort(3.toShort)
+    (0 until 10).foreach { i =>
+      assert(vector.getShort(i) == 3.toShort)
+    }
+  }
+
+  testVector("int", 10, IntegerType) { vector =>
+    vector.setInt(3)
+    (0 until 10).foreach { i =>
+      assert(vector.getInt(i) == 3)
+    }
+  }
+
+  testVector("long", 10, LongType) { vector =>
+    vector.setLong(3L)
+    (0 until 10).foreach { i =>
+      assert(vector.getLong(i) == 3L)
+    }
+  }
+
+  testVector("float", 10, FloatType) { vector =>
+    vector.setFloat(3.toFloat)
+    (0 until 10).foreach { i =>
+      assert(vector.getFloat(i) == 3.toFloat)
+    }
+  }
+
+  testVector("double", 10, DoubleType) { vector =>
+    vector.setDouble(3.toDouble)
+    (0 until 10).foreach { i =>
+      assert(vector.getDouble(i) == 3.toDouble)
+    }
+  }
+
+  testVector("array", 10, ArrayType(IntegerType)) { vector =>
+    // create a vector with constant array: [0, 1, 2, 3, 4]
+    val arrayVector = new OnHeapColumnVector(5, IntegerType)
+    (0 until 5).foreach { i =>
+      arrayVector.putInt(i, i)
+    }
+    val columnarArray = new ColumnarArray(arrayVector, 0, 5)
+
+    vector.setArray(columnarArray)
+
+    (0 until 10).foreach { i =>
+      assert(vector.getArray(i) == columnarArray)
+      assert(vector.getArray(i).toIntArray === Array(0, 1, 2, 3, 4))
+    }
+  }
+
+  testVector("map", 10, MapType(IntegerType, BooleanType)) { vector =>
+    // create a vector with constant map:
+    // [(0, true), (1, false), (2, true), (3, false), (4, true)]
+    val keys = new OnHeapColumnVector(5, IntegerType)
+    val values = new OnHeapColumnVector(5, BooleanType)
+
+    (0 until 5).foreach { i =>
+      keys.putInt(i, i)
+      values.putBoolean(i, i % 2 == 0)
+    }
+
+    val columnarMap = new ColumnarMap(keys, values, 0, 5)
+    vector.setMap(columnarMap)
+
+    (0 until 10).foreach { i =>
+      assert(vector.getMap(i) == columnarMap)
+      assert(vector.getMap(i).keyArray().toIntArray === Array(0, 1, 2, 3, 4))
+      assert(vector.getMap(i).valueArray().toBooleanArray ===
+        Array(true, false, true, false, true))
+    }
+  }
+
+  testVector("decimal", 10, DecimalType(10, 0)) { vector =>
+    val decimal = Decimal(100L)
+    vector.setDecimal(decimal, 10)
+    (0 until 10).foreach { i =>
+      assert(vector.getDecimal(i, 10, 0) == decimal)
+    }
+  }
+
+  testVector("utf8string", 10, StringType) { vector =>
+    vector.setUtf8String(UTF8String.fromString("hello"))
+    (0 until 10).foreach { i =>
+      assert(vector.getUTF8String(i) == UTF8String.fromString("hello"))
+    }
+  }
+
+  testVector("binary", 10, BinaryType) { vector =>
+    vector.setBinary("hello".getBytes("utf8"))
+    (0 until 10).foreach { i =>
+      assert(vector.getBinary(i) === "hello".getBytes("utf8"))
+    }
+  }
+
+  testVector("struct", 10,
+    new StructType()
+      .add(StructField("name", StringType))
+      .add(StructField("age", IntegerType))) { vector =>
+
+    val nameVector = new ConstantColumnVector(10, StringType)
+    nameVector.setUtf8String(UTF8String.fromString("jack"))
+    vector.setChild(0, nameVector)
+
+    val ageVector = new ConstantColumnVector(10, IntegerType)
+    ageVector.setInt(27)
+    vector.setChild(1, ageVector)
+
+
+    assert(vector.getChild(0) == nameVector)
+    assert(vector.getChild(1) == ageVector)
+    (0 until 10).foreach { i =>
+      assert(vector.getChild(0).getUTF8String(i) == UTF8String.fromString("jack"))
+      assert(vector.getChild(1).getInt(i) == 27)
+    }
+
+    // the same values are also accessible through the getStruct API
+    (0 until 10).foreach { i =>
+      assert(vector.getStruct(i).get(0, StringType) == UTF8String.fromString("jack"))
+      assert(vector.getStruct(i).get(1, IntegerType) == 27)
+    }
+  }
+
+  testVector("calendar interval", 10, CalendarIntervalType) { vector =>
+    val monthsVector = new ConstantColumnVector(10, IntegerType)
+    monthsVector.setInt(3)
+    val daysVector = new ConstantColumnVector(10, IntegerType)
+    daysVector.setInt(25)
+    val microsecondsVector = new ConstantColumnVector(10, LongType)
+    microsecondsVector.setLong(12345L)
+
+    vector.setChild(0, monthsVector)
+    vector.setChild(1, daysVector)
+    vector.setChild(2, microsecondsVector)
+
+    (0 until 10).foreach { i =>
+      assert(vector.getChild(0).getInt(i) == 3)
+      assert(vector.getChild(1).getInt(i) == 25)
+      assert(vector.getChild(2).getLong(i) == 12345L)
+    }
+  }
+}
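
Illustrative usage (not part of the patch): a minimal sketch of the new `set`/`get` API, assuming the `ConstantColumnVector` added above is on the classpath. The object name and all sample values are invented for the example; it shows one stored value serving every `rowId`, the precision-dependent decimal encoding, and the all-or-nothing null semantics.

```scala
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types.{Decimal, DecimalType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical driver object, for illustration only.
object ConstantVectorBasics {
  def main(args: Array[String]): Unit = {
    val numRows = 8

    // One UTF8String stored once, returned for every rowId.
    val path = new ConstantColumnVector(numRows, StringType)
    path.setUtf8String(UTF8String.fromString("/warehouse/t/part-0.parquet"))
    assert((0 until numRows).forall(i =>
      path.getUTF8String(i).toString == "/warehouse/t/part-0.parquet"))

    // Precision 20 exceeds Decimal.MAX_LONG_DIGITS, so the unscaled value
    // is stored as a byte array and rebuilt by getDecimal.
    val dec = new ConstantColumnVector(numRows, DecimalType(20, 2))
    dec.setDecimal(Decimal("123456789012345678.90"), 20)
    assert(dec.getDecimal(3, 20, 2) == Decimal("123456789012345678.90"))

    // Nullability is all-or-nothing: setNull() marks every row null.
    path.setNull()
    assert(path.hasNull && path.numNulls() == numRows && path.isNullAt(0))

    path.close()
    dec.close()
  }
}
```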
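A second sketch (also not part of the patch): combining constant metadata vectors with an existing `ColumnarBatch`, mirroring the shape of `FileScanRDD.addMetadataColumnsIfNeeded` above. The data column and the file name/size values are invented for the example.

```scala
import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, OnHeapColumnVector}
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical driver object, for illustration only.
object MetadataBatchSketch {
  def main(args: Array[String]): Unit = {
    val numRows = 3

    // Stand-in for the data column(s) produced by the file reader.
    val dataVector = new OnHeapColumnVector(numRows, IntegerType)
    (0 until numRows).foreach(i => dataVector.putInt(i, i * 10))
    val batch = new ColumnarBatch(Array[ColumnVector](dataVector), numRows)

    // Constant vectors for per-file metadata: each value is stored once
    // and shared by every row of the batch.
    val fileName = new ConstantColumnVector(numRows, StringType)
    fileName.setUtf8String(UTF8String.fromString("part-00000.parquet"))
    val fileSize = new ConstantColumnVector(numRows, LongType)
    fileSize.setLong(2048L)

    // Same shape as addMetadataColumnsIfNeeded: reuse the data columns
    // and append the metadata columns.
    val withMetadata = new ColumnarBatch(
      Array.tabulate(batch.numCols())(batch.column) ++ Array(fileName, fileSize),
      batch.numRows())

    val row = withMetadata.getRow(1)
    assert(row.getInt(0) == 10)
    assert(row.getUTF8String(1).toString == "part-00000.parquet")
    assert(row.getLong(2) == 2048L)
  }
}
```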