diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a050156518c2c..4705c27266dfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -901,6 +901,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val PARQUET_LIMIT_PUSHDOWN_ENABLED =
+    buildConf("spark.sql.parquet.limitPushdown.enabled")
+      .doc("Enables Parquet limit push-down optimization when set to true.")
+      .version("3.3.0")
+      .internal()
+      .booleanConf
+      .createWithDefault(true)
+
   val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
     .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
       "values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -3838,6 +3846,8 @@ class SQLConf extends Serializable with Logging {
 
   def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)
 
+  def parquetLimitPushDownEnabled: Boolean = getConf(PARQUET_LIMIT_PUSHDOWN_ENABLED)
+
   def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
 
   def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 0e976be2f652e..58ee88c22d744 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -56,6 +56,9 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
   // The capacity of vectorized batch.
   private int capacity;
 
+  // The pushed down Limit to read
+  private int limit;
+
   /**
    * Batch of rows that we assemble and the current index we've returned. Every time this
    * batch is used up (batchIdx == numBatched), we populated the batch.
@@ -139,7 +142,8 @@ public VectorizedParquetRecordReader(
       String int96RebaseMode,
       String int96RebaseTz,
       boolean useOffHeap,
-      int capacity) {
+      int capacity,
+      int limit) {
     this.convertTz = convertTz;
     this.datetimeRebaseMode = datetimeRebaseMode;
     this.datetimeRebaseTz = datetimeRebaseTz;
@@ -147,6 +151,19 @@ public VectorizedParquetRecordReader(
     this.int96RebaseTz = int96RebaseTz;
     MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
     this.capacity = capacity;
+    this.limit = limit;
+  }
+
+  public VectorizedParquetRecordReader(
+      ZoneId convertTz,
+      String datetimeRebaseMode,
+      String datetimeRebaseTz,
+      String int96RebaseMode,
+      String int96RebaseTz,
+      boolean useOffHeap,
+      int capacity) {
+    this(convertTz, datetimeRebaseMode, datetimeRebaseTz, int96RebaseMode, int96RebaseTz,
+      useOffHeap, capacity, Integer.MAX_VALUE);
   }
 
   // For test only.
@@ -302,10 +319,13 @@ public boolean nextBatch() throws IOException {
       vector.reset();
     }
     columnarBatch.setNumRows(0);
-    if (rowsReturned >= totalRowCount) return false;
+    if (rowsReturned >= totalRowCount || rowsReturned >= limit) {
+      return false;
+    }
     checkEndOfRowGroup();
 
-    int num = (int) Math.min(capacity, totalCountLoadedSoFar - rowsReturned);
+    int num = (int) Math.min(capacity,
+      Math.min(limit - rowsReturned, totalCountLoadedSoFar - rowsReturned));
     for (int i = 0; i < columnReaders.length; ++i) {
       if (columnReaders[i] == null) continue;
       columnReaders[i].readBatch(num, columnVectors[i]);
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 2dc4137d6f9a1..89bcc050c58ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -20,7 +20,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit, SupportsPushDownRequiredColumns}
 import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
 import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
 import org.apache.spark.sql.sources.Filter
@@ -32,7 +32,8 @@ abstract class FileScanBuilder(
     dataSchema: StructType)
   extends ScanBuilder
     with SupportsPushDownRequiredColumns
-    with SupportsPushDownCatalystFilters {
+    with SupportsPushDownCatalystFilters
+    with SupportsPushDownLimit {
   private val partitionSchema = fileIndex.partitionSchema
   private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
   protected val supportsNestedSchemaPruning = false
@@ -40,6 +41,7 @@ abstract class FileScanBuilder(
   protected var partitionFilters = Seq.empty[Expression]
   protected var dataFilters = Seq.empty[Expression]
   protected var pushedDataFilters = Array.empty[Filter]
+  protected var pushedLimit: Option[Int] = None
 
   override def pruneColumns(requiredSchema: StructType): Unit = {
     // [SPARK-30107] While `requiredSchema` might have pruned nested columns,
@@ -84,6 +86,8 @@ abstract class FileScanBuilder(
     dataFilters
   }
 
+  override def pushLimit(limit: Int): Boolean = false
+
   override def pushedFilters: Array[Filter] = pushedDataFilters
 
   /*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 12b8a631196ae..0b86c5105619d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -56,6 +56,7 @@ import org.apache.spark.util.SerializableConfiguration
 * @param partitionSchema Schema of partitions.
 * @param filters Filters to be pushed down in the batch scan.
 * @param aggregation Aggregation to be pushed down in the batch scan.
+ * @param limit Limit to be pushed down in the batch scan.
 * @param parquetOptions The options of Parquet datasource that are set for the read.
 */
 case class ParquetPartitionReaderFactory(
@@ -66,6 +67,7 @@ case class ParquetPartitionReaderFactory(
     partitionSchema: StructType,
     filters: Array[Filter],
     aggregation: Option[Aggregation],
+    limit: Option[Int],
     parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging {
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields)
@@ -321,7 +323,8 @@ case class ParquetPartitionReaderFactory(
       int96RebaseSpec.mode.toString,
       int96RebaseSpec.timeZone,
       enableOffHeapColumnVector && taskContext.isDefined,
-      capacity)
+      capacity,
+      limit.getOrElse(Int.MaxValue))
     val iter = new RecordReaderIterator(vectorizedReader)
     // SPARK-23457 Register a task completion listener before `initialization`.
     taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 6b35f2406a82f..be98857d0be2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -46,7 +46,8 @@ case class ParquetScan(
     options: CaseInsensitiveStringMap,
     pushedAggregate: Option[Aggregation] = None,
     partitionFilters: Seq[Expression] = Seq.empty,
-    dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
+    dataFilters: Seq[Expression] = Seq.empty,
+    pushedLimit: Option[Int] = None) extends FileScan {
   override def isSplitable(path: Path): Boolean = {
     // If aggregate is pushed down, only the file footer will be read once,
     // so file should not be split across multiple tasks.
@@ -97,6 +98,7 @@ case class ParquetScan(
       readPartitionSchema,
       pushedFilters,
       pushedAggregate,
+      pushedLimit,
       new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
   }
 
@@ -124,7 +126,8 @@ case class ParquetScan(
   override def description(): String = {
     super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
       ", PushedAggregation: " + pushedAggregationsStr +
-      ", PushedGroupBy: " + pushedGroupByStr
+      ", PushedGroupBy: " + pushedGroupByStr +
+      ", PushedLimit: " + pushedLimit
   }
 
   override def getMetaData(): Map[String, String] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 1f2f75aebd7bf..dd51ba90a0a99 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -105,6 +105,17 @@ case class ParquetScanBuilder(
     }
   }
 
+  override def pushLimit(limit: Int): Boolean = {
+    val sqlConf = sparkSession.sessionState.conf
+    // TODO: Support limit push down for row based parquet reader
+    if (sqlConf.parquetVectorizedReaderEnabled && sqlConf.parquetLimitPushDownEnabled) {
+      pushedLimit = Some(limit)
+      true
+    } else {
+      false
+    }
+  }
+
   override def build(): Scan = {
     // the `finalSchema` is either pruned in pushAggregation (if aggregates are
     // pushed down), or pruned in readDataSchema() (in regular column pruning). These
@@ -114,6 +125,6 @@ case class ParquetScanBuilder(
     }
     ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
       readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
-      partitionFilters, dataFilters)
+      partitionFilters, dataFilters, pushedLimit)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetLimitPushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetLimitPushDownSuite.scala
new file mode 100644
index 0000000000000..c45610cb4e7ea
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetLimitPushDownSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import scala.util.Random
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.CollectLimitExec
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+/**
+ * A test suite for the Parquet-based limit pushdown optimization.
+ */
+class ParquetLimitPushDownSuite extends QueryTest with ParquetTest with SharedSparkSession {
+  test("SPARK-37933: limit pushdown for vectorized parquet reader") {
+    import testImplicits._
+    withSQLConf(
+      SQLConf.PARQUET_LIMIT_PUSHDOWN_ENABLED.key -> "true",
+      SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
+      SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+      withTempPath { path =>
+        (1 to 1024).map(i => (101, i)).toDF("a", "b").coalesce(1).write.parquet(path.getPath)
+        // Use a strictly positive limit: limit(0) would be optimized to an empty local relation.
+        val pushedLimit = 1 + Random.nextInt(100)
+        val df = spark.read.parquet(path.getPath).limit(pushedLimit)
+        val sparkPlan = df.queryExecution.sparkPlan
+        sparkPlan.foreachUp {
+          case r @ BatchScanExec(_, f: ParquetScan, _) =>
+            assert(f.pushedLimit.contains(pushedLimit))
+            assert(r.executeColumnar().map(_.numRows()).sum() == pushedLimit)
+          case CollectLimitExec(limit, _) =>
+            assert(limit == pushedLimit)
+          case _ => // other plan nodes are not relevant to this check
+        }
+        assert(df.count() == pushedLimit)
+      }
+    }
+  }
+}
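
Usage sketch (not part of the patch): the standalone snippet below illustrates how the new spark.sql.parquet.limitPushdown.enabled flag could be exercised once this change is built. It is a minimal sketch under assumptions not taken from the diff: a local SparkSession, a temporary directory for the data, and the hypothetical object name ParquetLimitPushdownExample. It routes parquet reads through the DataSource V2 path (empty useV1SourceList) so that ParquetScanBuilder.pushLimit is consulted, and the scan description in the physical plan should then report the pushed limit.

import org.apache.spark.sql.SparkSession

object ParquetLimitPushdownExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("parquet-limit-pushdown-sketch")
      // Read parquet through DataSource V2 so ParquetScanBuilder.pushLimit applies.
      .config("spark.sql.sources.useV1SourceList", "")
      .config("spark.sql.parquet.limitPushdown.enabled", "true")
      .config("spark.sql.parquet.enableVectorizedReader", "true")
      .getOrCreate()
    import spark.implicits._

    // Write a small single-file parquet dataset, then read it back with a limit.
    val path = java.nio.file.Files.createTempDirectory("parquet-limit-pushdown").toString
    (1 to 1024).map(i => (101, i)).toDF("a", "b")
      .coalesce(1).write.mode("overwrite").parquet(path)

    // The ParquetScan node in the plan output should show "PushedLimit: Some(10)".
    spark.read.parquet(path).limit(10).explain()

    spark.stop()
  }
}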