diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 672b73e94c42..8dc167badd99 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -21,7 +21,7 @@
 import java.math.BigInteger;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
-import java.util.Arrays;
+import java.util.*;
 
 import org.apache.parquet.bytes.ByteBufferInputStream;
 import org.apache.parquet.bytes.BytesInput;
@@ -40,6 +40,7 @@
 import org.apache.spark.sql.catalyst.util.RebaseDateTime;
 import org.apache.spark.sql.execution.datasources.DataSourceUtils;
 import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException;
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
@@ -111,6 +112,12 @@ public class VectorizedColumnReader {
   private final String datetimeRebaseMode;
   private final String int96RebaseMode;
 
+  // TODO: handle and initialize these fields properly
+  private PrimitiveIterator.OfLong rowIndexesIterator;
+  private long[] rowIndexes; // row indexes of the current row group
+  private long currentRow = 0; // current row to read
+  private WritableColumnVector tempColumnVector;
+
   private boolean isDecimalTypeMatched(DataType dt) {
     DecimalType d = (DecimalType) dt;
     DecimalMetadata dm = descriptor.getPrimitiveType().getDecimalMetadata();
@@ -140,7 +147,10 @@ public VectorizedColumnReader(
       PageReader pageReader,
       ZoneId convertTz,
       String datetimeRebaseMode,
-      String int96RebaseMode) throws IOException {
+      String int96RebaseMode,
+      PrimitiveIterator.OfLong rowIndexesIterator
+  ) throws IOException {
+    this.rowIndexesIterator = rowIndexesIterator;
     this.descriptor = descriptor;
     this.pageReader = pageReader;
     this.convertTz = convertTz;
@@ -248,15 +258,48 @@ static long rebaseInt96(long julianMicros, final boolean failIfRebase) {
   /**
    * Reads `total` values from this columnReader into column.
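+   * When a row-index iterator is present (Parquet column index filtering), values may first be
+   * read into a temporary vector and only the rows whose indexes match the iterator are copied
+   * into `resultColumn`.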
    */
-  void readBatch(int total, WritableColumnVector column) throws IOException {
+  void readBatch(int total, int columnSize, WritableColumnVector resultColumn) throws IOException {
+    PrimitiveType.PrimitiveTypeName typeName =
+        descriptor.getPrimitiveType().getPrimitiveTypeName();
+
+    WritableColumnVector column;
+
+    if (rowIndexesIterator != null) {
+      if (tempColumnVector == null) {
+        tempColumnVector = new OnHeapColumnVector(columnSize, resultColumn.dataType());
+      }
+      column = tempColumnVector;
+      column.reset();
+
+      rowIndexes = new long[total];
+      for (int i = 0; i < total; i++) {
+        rowIndexes[i] = rowIndexesIterator.next();
+      }
+
+      // If the row indexes exactly match the range we are going to read,
+      // there is no need to do additional row index synchronization.
+      boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total;
+      if (continuousRange && rowIndexes[0] == currentRow) {
+        column = resultColumn;
+      }
+    } else {
+      // Write to the result column directly if no row indexes are present.
+      column = resultColumn;
+    }
+
     int rowId = 0;
     WritableColumnVector dictionaryIds = null;
+    WritableColumnVector resultDictionaryIds = null;
     if (dictionary != null) {
       // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
       // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
       // page.
       dictionaryIds = column.reserveDictionaryIds(total);
+      if (column != resultColumn) {
+        resultDictionaryIds = resultColumn.reserveDictionaryIds(total);
+      }
     }
+
     while (total > 0) {
       // Compute the number of values we want to read in this page.
       int leftInPage = (int) (endOfPageValueCount - valuesRead);
@@ -265,8 +308,6 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
         leftInPage = (int) (endOfPageValueCount - valuesRead);
       }
       int num = Math.min(total, leftInPage);
-      PrimitiveType.PrimitiveTypeName typeName =
-          descriptor.getPrimitiveType().getPrimitiveTypeName();
       if (isCurrentPageDictionaryEncoded) {
         // Read and decode dictionary ids.
         defColumn.readIntegers(
@@ -297,6 +338,9 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
           boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
           column.setDictionary(new ParquetDictionary(dictionary, needTransform));
+          if (column != resultColumn) { // set the dictionary on the result column as well
+            resultColumn.setDictionary(new ParquetDictionary(dictionary, needTransform));
+          }
         } else {
           decodeDictionaryIds(rowId, num, column, dictionaryIds);
         }
@@ -304,7 +348,11 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
         if (column.hasDictionary() && rowId != 0) {
          // This batch already has dictionary encoded values but this new page is not. The batch
          // does not support a mix of dictionary and not so we will decode the dictionary.
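+          // When a temporary vector is in use, the dictionary ids read so far were synchronized
+          // into resultColumn, so the previously read values have to be decoded there.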
-          decodeDictionaryIds(0, rowId, column, column.getDictionaryIds());
+          if (column != resultColumn) {
+            decodeDictionaryIds(0, rowId, resultColumn, resultColumn.getDictionaryIds());
+          } else {
+            decodeDictionaryIds(0, rowId, column, column.getDictionaryIds());
+          }
         }
         column.setDictionary(null);
         switch (typeName) {
@@ -338,9 +386,87 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
         }
       }
 
+      if (resultColumn != column) {
+        // copy values from the temporary column to the result column
+        boolean continuousRange = (rowIndexes[total - 1] - rowIndexes[0] + 1) == total;
+        if (continuousRange) {
+          // skip to the offset position and copy all remaining values
+          int offset = (int) (rowIndexes[rowId] - currentRow);
+          if (offset < num) {
+            int validValueNum = num - offset;
+            if (isCurrentPageDictionaryEncoded && column.hasDictionary()) {
+              resultDictionaryIds.putInts(rowId, validValueNum, dictionaryIds.getInts(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.ByteType) {
+              resultColumn.putBytes(rowId, validValueNum, column.getBytes(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.ShortType) {
+              resultColumn.putShorts(rowId, validValueNum, column.getShorts(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.IntegerType) {
+              resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.LongType) {
+              resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.FloatType) {
+              resultColumn.putFloats(rowId, validValueNum, column.getFloats(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.DoubleType) {
+              resultColumn.putDoubles(rowId, validValueNum, column.getDoubles(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.DateType) {
+              resultColumn.putInts(rowId, validValueNum, column.getInts(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.TimestampType) {
+              resultColumn.putLongs(rowId, validValueNum, column.getLongs(rowId + offset, validValueNum), 0);
+            } else if (resultColumn.dataType() == DataTypes.BooleanType) {
+              for (int i = 0; i < validValueNum; i++) {
+                resultColumn.putBoolean(rowId + i, column.getBoolean(rowId + offset + i));
+              }
+            } else {
+              for (int i = 0; i < validValueNum; i++) {
+                resultColumn.putByteArray(rowId + i, column.getBinary(rowId + offset + i));
+              }
+            }
+            rowId += validValueNum;
+            total -= validValueNum;
+          }
+        } else {
+          // need to check every row
+          for (int i = 0, startingRowId = rowId; i < num && total > 0; ) {
+            while (currentRow + i < rowIndexes[rowId] && i < num) {
+              i++;
+            }
+            if (i >= num) {
+              break;
+            }
+            if (isCurrentPageDictionaryEncoded && column.hasDictionary()) {
+              resultDictionaryIds.putInt(rowId, dictionaryIds.getInt(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.ByteType) {
+              resultColumn.putByte(rowId, column.getByte(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.ShortType) {
+              resultColumn.putShort(rowId, column.getShort(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.IntegerType) {
+              resultColumn.putInt(rowId, column.getInt(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.LongType) {
+              resultColumn.putLong(rowId, column.getLong(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.DateType) {
+              resultColumn.putInt(rowId, column.getInt(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.FloatType) {
+              resultColumn.putFloat(rowId, column.getFloat(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.DoubleType) {
+              resultColumn.putDouble(rowId, column.getDouble(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.TimestampType) {
+              resultColumn.putLong(rowId, column.getLong(startingRowId + i));
+            } else if (resultColumn.dataType() == DataTypes.BooleanType) {
+              resultColumn.putBoolean(rowId, column.getBoolean(startingRowId + i));
+            } else {
+              resultColumn.putByteArray(rowId, column.getBinary(startingRowId + i));
+            }
+            rowId++;
+            total--;
+          }
+        }
+      } else {
+        rowId += num;
+        total -= num;
+      }
+
+      currentRow += num;
       valuesRead += num;
-      rowId += num;
-      total -= num;
     }
   }
 
@@ -853,6 +979,7 @@ private void initDataReader(Encoding dataEncoding, ByteBufferInputStream in) thr
   }
 
   private void readPageV1(DataPageV1 page) throws IOException {
+    this.currentRow = page.getFirstRowIndex().orElse(0L);
    this.pageValueCount = page.getValueCount();
    ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL);
    ValuesReader dlReader;
@@ -878,6 +1005,7 @@ private void readPageV1(DataPageV1 page) throws IOException {
   }
 
   private void readPageV2(DataPageV2 page) throws IOException {
+    this.currentRow = page.getFirstRowIndex().orElse(0L);
    this.pageValueCount = page.getValueCount();
    this.repetitionLevelColumn = createRLEIterator(descriptor.getMaxRepetitionLevel(),
      page.getRepetitionLevels(), descriptor);
@@ -894,4 +1022,5 @@ private void readPageV2(DataPageV2 page) throws IOException {
       throw new IOException("could not read page " + page + " in col " + descriptor, e);
     }
   }
+
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 1b159534c8a4..68606ea29990 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -280,7 +280,7 @@ public boolean nextBatch() throws IOException {
     int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned);
     for (int i = 0; i < columnReaders.length; ++i) {
       if (columnReaders[i] == null) continue;
-      columnReaders[i].readBatch(num, columnVectors[i]);
+      columnReaders[i].readBatch(num, capacity, columnVectors[i]);
     }
     rowsReturned += num;
     columnarBatch.setNumRows(num);
@@ -336,7 +336,8 @@ private void checkEndOfRowGroup() throws IOException {
         pages.getPageReader(columns.get(i)),
         convertTz,
         datetimeRebaseMode,
-        int96RebaseMode);
+        int96RebaseMode,
+        pages.getRowIndexes().orElse(null));
     }
     totalCountLoadedSoFar += pages.getRowCount();
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala
new file mode 100644
index 000000000000..041ece7da4d4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/ParquetColumnIndexBenchmark.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.io.File
+
+import scala.util.Random
+
+import org.apache.parquet.hadoop.ParquetInputFormat
+
+import org.apache.spark.SparkConf
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+/**
+ * Benchmark to measure read performance with Parquet column index.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt: bin/spark-submit --class <this class> <spark sql test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/ParquetColumnIndexBenchmark-results.txt".
+ * }}}
+ */
+object ParquetColumnIndexBenchmark extends SqlBasedBenchmark {
+
+  override def getSparkSession: SparkSession = {
+    val conf = new SparkConf()
+      .setAppName(this.getClass.getSimpleName)
+      // Since `spark.master` always exists, overrides this value
+      .set("spark.master", "local[1]")
+      .setIfMissing("spark.driver.memory", "3g")
+      .setIfMissing("spark.executor.memory", "3g")
+      .setIfMissing("orc.compression", "snappy")
+      .setIfMissing("spark.sql.parquet.compression.codec", "snappy")
+
+    SparkSession.builder().config(conf).getOrCreate()
+  }
+
+  private val numRows = 1024 * 1024 * 15
+  private val width = 5
+  private val mid = numRows / 2
+
+  def withTempTable(tableNames: String*)(f: => Unit): Unit = {
+    try f finally tableNames.foreach(spark.catalog.dropTempView)
+  }
+
+  private def prepareTable(dir: File, numRows: Int): Unit = {
+    import spark.implicits._
+
+    val df = spark.range(numRows).map(i => (i, i + ":f" + "o" * Random.nextInt(200))).toDF()
+
+    saveAsTable(df, dir)
+  }
+
+  private def saveAsTable(df: DataFrame, dir: File, useDictionary: Boolean = false): Unit = {
+    val parquetPath = dir.getCanonicalPath + "/parquet"
+    df.write.mode("overwrite").parquet(parquetPath)
+    spark.read.parquet(parquetPath).createOrReplaceTempView("parquetTable")
+  }
+
+  def filterPushDownBenchmark(
+      values: Int,
+      title: String,
+      whereExpr: String,
+      selectExpr: String = "*"): Unit = {
+    val benchmark = new Benchmark(title, values, minNumIters = 5, output = output)
+
+    Seq(false, true).foreach { columnIndexEnabled =>
+      val name = s"Parquet Vectorized ${if (columnIndexEnabled) s"(columnIndex)" else ""}"
+      benchmark.addCase(name) { _ =>
+        withSQLConf(ParquetInputFormat.COLUMN_INDEX_FILTERING_ENABLED -> s"$columnIndexEnabled") {
+          spark.sql(s"SELECT $selectExpr FROM parquetTable WHERE $whereExpr").noop()
+        }
+      }
+    }
+
+    benchmark.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("Pushdown for single value filter") {
+      withTempPath { dir =>
+        withTempTable("parquetTable") {
+          prepareTable(dir, numRows)
+          filterPushDownBenchmark(numRows, "simple filters", s" _1 = $numRows - 100 ")
+        }
+      }
+    }
+
+    runBenchmark("Pushdown for range filter") {
+      withTempPath { dir =>
+        withTempTable("parquetTable") {
+          prepareTable(dir, numRows)
+          filterPushDownBenchmark(numRows,
+            "range filters", s" _1 > ($numRows - 1000000) and _1 < ($numRows - 1000)")
+        }
+      }
+    }
+
+    runBenchmark("Pushdown for multi range filter") {
+      withTempPath { dir =>
+        withTempTable("parquetTable") {
+          prepareTable(dir, numRows)
+          filterPushDownBenchmark(numRows,
+            "multi range filters",
+            s" (_1 > ($numRows - 3000000) and _1 < ($numRows - 2000000))" +
+              s" or ( _1 > ($numRows - 1000000) and _1 < ($numRows - 1000))")
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
new file mode 100644
index 000000000000..44bb93ff9c35
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSparkSession {
+
+  import testImplicits._
+
+  /**
+   * Create a Parquet file with two columns and unaligned pages.
+   * The pages will have the following layout:
+   * col_1   500       500       500       500
+   *        |---------|---------|---------|---------|
+   *        |-------|-----|-----|---|---|---|---|---|
+   * col_2  400     300   200   200 200 200 200 200
+   */
+  def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = {
+    withTempPath(file => {
+      val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt))
+      ds.coalesce(1)
+        .write
+        .option("parquet.page.size", "4096")
+        .parquet(file.getCanonicalPath)
+
+      val parquetDf = spark.read.parquet(file.getCanonicalPath)
+
+      actions.foreach { action =>
+        checkAnswer(action(parquetDf), action(ds.toDF()))
+      }
+    })
+  }
+
+  test("reading from unaligned pages - test filters") {
+    checkUnalignedPages(
+      // single value filter
+      df => df.filter("_1 = 500"),
+      df => df.filter("_1 = 500 or _1 = 1500"),
+      df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500"),
+      df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500"),
+      // range filter
+      df => df.filter("_1 >= 500 and _1 < 1000"),
+      df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)")
+    )
+  }
+
+  test("test reading unaligned pages - test all types") {
+    withTempPath(file => {
+      val df = spark.range(0, 2000).selectExpr(
+        "id as _1",
+        "cast(id as short) as _3",
+        "cast(id as int) as _4",
+        "cast(id as float) as _5",
+        "cast(id as double) as _6",
+        "cast(id as decimal(20,0)) as _7",
+        "cast(cast(1618161925000 + id * 1000 * 60 * 60 * 24 as timestamp) as date) as _9",
+        "cast(1618161925000 + id as timestamp) as _10"
+      )
+      df.coalesce(1)
+        .write
+        .option("parquet.page.size", "4096")
+        .parquet(file.getCanonicalPath)
+
+      val parquetDf = spark.read.parquet(file.getCanonicalPath)
+      val singleValueFilterExpr = "_1 = 500 or _1 = 1500"
+      checkAnswer(
+        parquetDf.filter(singleValueFilterExpr),
+        df.filter(singleValueFilterExpr)
+      )
+      val rangeFilterExpr = "_1 > 500"
+      checkAnswer(
+        parquetDf.filter(rangeFilterExpr),
+        df.filter(rangeFilterExpr)
+      )
+    })
+  }
+
+  test("test reading unaligned pages - test all types (dict encode)") {
+    withTempPath(file => {
+      val df = spark.range(0, 2000).selectExpr(
+        "id as _1",
+        "cast(id % 10 as byte) as _2",
+        "cast(id % 10 as short) as _3",
+        "cast(id % 10 as int) as _4",
+        "cast(id % 10 as float) as _5",
+        "cast(id % 10 as double) as _6",
+        "cast(id % 10 as decimal(20,0)) as _7",
+        "cast(id % 2 as boolean) as _8",
+        "cast(cast(1618161925000 + (id % 10) * 1000 * 60 * 60 * 24 as timestamp) as date) as _9",
+        "cast(1618161925000 + (id % 10) as timestamp) as _10"
+      )
+      df.coalesce(1)
+        .write
+        .option("parquet.page.size", "4096")
+        .parquet(file.getCanonicalPath)
+
+      val parquetDf = spark.read.parquet(file.getCanonicalPath)
+      val singleValueFilterExpr = "_1 = 500 or _1 = 1500"
+      checkAnswer(
+        parquetDf.filter(singleValueFilterExpr),
+        df.filter(singleValueFilterExpr)
+      )
+      val rangeFilterExpr = "_1 > 500"
+      checkAnswer(
+        parquetDf.filter(rangeFilterExpr),
+        df.filter(rangeFilterExpr)
+      )
+    })
+  }
+}