diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index be68880e49a8..0c82c0333aba 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -153,10 +153,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
     this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString);
     this.reader = new ParquetFileReader(
       configuration, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns());
-    // use the blocks from the reader in case some do not match filters and will not be read
-    for (BlockMetaData block : reader.getRowGroups()) {
-      this.totalRowCount += block.getRowCount();
-    }
+    this.totalRowCount = reader.getFilteredRecordCount();
 
     // For test purpose.
     // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
@@ -232,10 +229,7 @@ protected void initialize(String path, List<String> columns) throws IOException
     this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema);
     this.reader = new ParquetFileReader(
       config, footer.getFileMetaData(), file, blocks, requestedSchema.getColumns());
-    // use the blocks from the reader in case some do not match filters and will not be read
-    for (BlockMetaData block : reader.getRowGroups()) {
-      this.totalRowCount += block.getRowCount();
-    }
+    this.totalRowCount = reader.getFilteredRecordCount();
   }
 
   @Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 9d38a74a2956..1b159534c8a4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -320,7 +320,7 @@ private void initializeInternal() throws IOException, UnsupportedOperationExcept
 
   private void checkEndOfRowGroup() throws IOException {
     if (rowsReturned != totalCountLoadedSoFar) return;
-    PageReadStore pages = reader.readNextRowGroup();
+    PageReadStore pages = reader.readNextFilteredRowGroup();
     if (pages == null) {
       throw new IOException("expecting more rows but reached last block. Read "
           + rowsReturned + " out of " + totalRowCount);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 24a1ba124e56..a546538adc56 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -25,9 +25,12 @@ import java.time.{LocalDate, LocalDateTime, ZoneId}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.TypeTag
 
+import org.apache.hadoop.fs.Path
 import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
 import org.apache.parquet.filter2.predicate.Operators.{Column => _, _}
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetInputFormat, ParquetOutputFormat}
+import org.apache.parquet.hadoop.util.HadoopInputFile
 import org.apache.parquet.schema.MessageType
 
 import org.apache.spark.{SparkConf, SparkException}
@@ -46,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.tags.ExtendedSQLTest
-import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
+import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
 
 /**
  * A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -1571,6 +1574,66 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
       }
     }
   }
+
+  test("Support Parquet column index") {
+    // block 1:
+    //          null count  min  max
+    // page-0   0           0    99
+    // page-1   0           100  199
+    // page-2   0           200  299
+    // page-3   0           300  399
+    // page-4   0           400  449
+    //
+    // block 2:
+    //          null count  min  max
+    // page-0   0           450  549
+    // page-1   0           550  649
+    // page-2   0           650  749
+    // page-3   0           750  849
+    // page-4   0           850  899
+    withTempPath { path =>
+      spark.range(900)
+        .repartition(1)
+        .write
+        .option(ParquetOutputFormat.PAGE_SIZE, "500")
+        .option(ParquetOutputFormat.BLOCK_SIZE, "2000")
+        .parquet(path.getCanonicalPath)
+
+      val parquetFile = path.listFiles().filter(_.getName.startsWith("part")).last
+      val in = HadoopInputFile.fromPath(
+        new Path(parquetFile.getCanonicalPath),
+        spark.sessionState.newHadoopConf())
+
+      Utils.tryWithResource(ParquetFileReader.open(in)) { reader =>
+        val blocks = reader.getFooter.getBlocks
+        assert(blocks.size() > 1)
+        val columns = blocks.get(0).getColumns
+        assert(columns.size() === 1)
+        val columnIndex = reader.readColumnIndex(columns.get(0))
+        assert(columnIndex.getMinValues.size() > 1)
+
+        val rowGroupCnt = blocks.get(0).getRowCount
+        // Page count = Second page min value - first page min value
+        val pageCnt = columnIndex.getMinValues.get(1).asLongBuffer().get() -
+          columnIndex.getMinValues.get(0).asLongBuffer().get()
+        assert(pageCnt < rowGroupCnt)
+        Seq(true, false).foreach { columnIndex =>
+          withSQLConf(ParquetInputFormat.COLUMN_INDEX_FILTERING_ENABLED -> s"$columnIndex") {
+            val df = spark.read.parquet(parquetFile.getCanonicalPath).where("id = 1")
+            df.collect()
+            val plan = df.queryExecution.executedPlan
+            val metrics = plan.collectLeaves().head.metrics
+            val numOutputRows = metrics("numOutputRows").value
+            if (columnIndex) {
+              assert(numOutputRows === pageCnt)
+            } else {
+              assert(numOutputRows === rowGroupCnt)
+            }
+          }
+        }
+      }
+    }
+  }
 }
 
 @ExtendedSQLTest
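
For context, a minimal standalone sketch (not part of this patch) of the parquet-mr column-index read path that the change above adopts. The class name, the /tmp/example.parquet path, and the id = 1 predicate are illustrative assumptions; the ParquetFileReader calls (open, getFilteredRecordCount, readNextFilteredRowGroup) are the same public parquet-mr 1.11+ APIs used in the patch.

// Sketch only. With a record filter supplied and column indexes present,
// getFilteredRecordCount() and readNextFilteredRowGroup() apply page-level
// (column index) filtering, so their totals can be smaller than summing
// BlockMetaData.getRowCount() over reader.getRowGroups(); that is why the
// patch replaces the summing loop in SpecificParquetRecordReaderBase.
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.ParquetReadOptions;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.filter2.compat.FilterCompat;
import org.apache.parquet.filter2.predicate.FilterApi;
import org.apache.parquet.filter2.predicate.FilterPredicate;
import org.apache.parquet.hadoop.ParquetFileReader;
import org.apache.parquet.hadoop.util.HadoopInputFile;

public class ColumnIndexReadSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // Assumed input file; any Parquet file written with small page/block
    // sizes (as in the test above) will have multiple pages per row group.
    HadoopInputFile file =
        HadoopInputFile.fromPath(new Path("/tmp/example.parquet"), conf);

    // Same shape as the test's "id = 1" predicate.
    FilterPredicate pred = FilterApi.eq(FilterApi.longColumn("id"), 1L);
    ParquetReadOptions options = ParquetReadOptions.builder()
        .withRecordFilter(FilterCompat.get(pred))
        .useColumnIndexFilter(true) // default in recent parquet-mr; shown explicitly
        .build();

    try (ParquetFileReader reader = ParquetFileReader.open(file, options)) {
      // Counterpart of the patched totalRowCount assignment: counts rows that
      // survive row-group AND page-level filtering, not raw block row counts.
      long totalRowCount = reader.getFilteredRecordCount();

      PageReadStore pages;
      while ((pages = reader.readNextFilteredRowGroup()) != null) {
        // Mirrors the patched checkEndOfRowGroup(): only pages whose
        // column-index ranges can match the predicate are loaded.
        System.out.println("rows in filtered row group: " + pages.getRowCount());
      }
      System.out.println("total filtered records: " + totalRowCount);
    }
  }
}

Note that when no record filter is set (or column indexes are absent), readNextFilteredRowGroup() falls back to the behavior of readNextRowGroup(), which is why the vectorized reader can switch to it unconditionally.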