@@ -901,6 +901,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val PARQUET_LIMIT_PUSHDOWN_ENABLED =
buildConf("spark.sql.parquet.limitPushdown.enabled")
.doc("Enables Parquet limit push-down optimization when set to true.")
.version("3.3.0")
.internal()
.booleanConf
.createWithDefault(true)

val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
.doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
"values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -3838,6 +3846,8 @@ class SQLConf extends Serializable with Logging {

def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)

def parquetLimitPushDownEnabled: Boolean = getConf(PARQUET_LIMIT_PUSHDOWN_ENABLED)

def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)

def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED)
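For reviewers, a minimal spark-shell sketch of toggling the new flag (the key is the one defined above; the conf is internal and defaults to true, so it only needs to be touched to opt out or to compare behaviour):

// Read the current value of the internal flag; it defaults to "true".
spark.conf.get("spark.sql.parquet.limitPushdown.enabled")

// Opt out, e.g. to compare plans with and without the optimization.
spark.conf.set("spark.sql.parquet.limitPushdown.enabled", "false")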
@@ -56,6 +56,9 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
// The capacity of vectorized batch.
private int capacity;

// The pushed-down limit, i.e. the maximum number of rows this reader should return.
private int limit;

/**
* Batch of rows that we assemble and the current index we've returned. Every time this
* batch is used up (batchIdx == numBatched), we populate the batch.
@@ -139,14 +142,28 @@ public VectorizedParquetRecordReader(
String int96RebaseMode,
String int96RebaseTz,
boolean useOffHeap,
int capacity) {
int capacity,
int limit) {
this.convertTz = convertTz;
this.datetimeRebaseMode = datetimeRebaseMode;
this.datetimeRebaseTz = datetimeRebaseTz;
this.int96RebaseMode = int96RebaseMode;
this.int96RebaseTz = int96RebaseTz;
MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
this.capacity = capacity;
this.limit = limit;
}

public VectorizedParquetRecordReader(
ZoneId convertTz,
String datetimeRebaseMode,
String datetimeRebaseTz,
String int96RebaseMode,
String int96RebaseTz,
boolean useOffHeap,
int capacity) {
this(convertTz, datetimeRebaseMode, datetimeRebaseTz, int96RebaseMode, int96RebaseTz,
useOffHeap, capacity, Integer.MAX_VALUE);
}

// For test only.
@@ -302,10 +319,13 @@ public boolean nextBatch() throws IOException {
vector.reset();
}
columnarBatch.setNumRows(0);
if (rowsReturned >= totalRowCount) return false;
if (rowsReturned >= totalRowCount || rowsReturned >= limit) {
return false;
}
checkEndOfRowGroup();

int num = (int) Math.min(capacity, totalCountLoadedSoFar - rowsReturned);
int num = (int) Math.min(capacity,
Math.min(limit - rowsReturned, totalCountLoadedSoFar - rowsReturned));
for (int i = 0; i < columnReaders.length; ++i) {
if (columnReaders[i] == null) continue;
columnReaders[i].readBatch(num, columnVectors[i]);
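The reader-side change amounts to two things: nextBatch() now returns false once rowsReturned reaches the pushed limit, and each batch is additionally clamped by the remaining limit. A standalone Scala sketch of that sizing rule (nextBatchSize is an illustrative helper, not reader code; the values are made up):

// Mirrors min(capacity, min(limit - rowsReturned, totalCountLoadedSoFar - rowsReturned)).
def nextBatchSize(capacity: Int, limit: Int, rowsReturned: Long, loadedSoFar: Long): Int =
  math.min(capacity.toLong, math.min(limit - rowsReturned, loadedSoFar - rowsReturned)).toInt

nextBatchSize(capacity = 4096, limit = 100, rowsReturned = 0L, loadedSoFar = 8192L)          // 100
nextBatchSize(capacity = 4096, limit = Int.MaxValue, rowsReturned = 0L, loadedSoFar = 8192L) // 4096, the old behaviour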
@@ -20,7 +20,7 @@ import scala.collection.mutable

import org.apache.spark.sql.{sources, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownLimit, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
import org.apache.spark.sql.sources.Filter
@@ -32,14 +32,16 @@ abstract class FileScanBuilder(
dataSchema: StructType)
extends ScanBuilder
with SupportsPushDownRequiredColumns
with SupportsPushDownCatalystFilters {
with SupportsPushDownCatalystFilters
with SupportsPushDownLimit {
private val partitionSchema = fileIndex.partitionSchema
private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
protected val supportsNestedSchemaPruning = false
protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)
protected var partitionFilters = Seq.empty[Expression]
protected var dataFilters = Seq.empty[Expression]
protected var pushedDataFilters = Array.empty[Filter]
protected var pushedLimit: Option[Int] = None

override def pruneColumns(requiredSchema: StructType): Unit = {
// [SPARK-30107] While `requiredSchema` might have pruned nested columns,
@@ -84,6 +86,8 @@ abstract class FileScanBuilder(
dataFilters
}

override def pushLimit(limit: Int): Boolean = false

override def pushedFilters: Array[Filter] = pushedDataFilters

/*
@@ -56,6 +56,7 @@ import org.apache.spark.util.SerializableConfiguration
* @param partitionSchema Schema of partitions.
* @param filters Filters to be pushed down in the batch scan.
* @param aggregation Aggregation to be pushed down in the batch scan.
* @param limit Limit to be pushed down in the batch scan.
* @param parquetOptions The options of Parquet datasource that are set for the read.
*/
case class ParquetPartitionReaderFactory(
@@ -66,6 +67,7 @@ case class ParquetPartitionReaderFactory(
partitionSchema: StructType,
filters: Array[Filter],
aggregation: Option[Aggregation],
limit: Option[Int],
parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging {
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields)
@@ -321,7 +323,8 @@ case class ParquetPartitionReaderFactory(
int96RebaseSpec.mode.toString,
int96RebaseSpec.timeZone,
enableOffHeapColumnVector && taskContext.isDefined,
capacity)
capacity,
limit.getOrElse(Int.MaxValue))
val iter = new RecordReaderIterator(vectorizedReader)
// SPARK-23457 Register a task completion listener before `initialization`.
taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
@@ -46,7 +46,8 @@ case class ParquetScan(
options: CaseInsensitiveStringMap,
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
dataFilters: Seq[Expression] = Seq.empty,
pushedLimit: Option[Int] = None) extends FileScan {
override def isSplitable(path: Path): Boolean = {
// If aggregate is pushed down, only the file footer will be read once,
// so file should not be split across multiple tasks.
@@ -97,6 +98,7 @@ case class ParquetScan(
readPartitionSchema,
pushedFilters,
pushedAggregate,
pushedLimit,
new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
}

@@ -124,7 +126,8 @@ case class ParquetScan(
override def description(): String = {
super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
", PushedAggregation: " + pushedAggregationsStr +
", PushedGroupBy: " + pushedGroupByStr
", PushedGroupBy: " + pushedGroupByStr +
", PushedLimit: " + pushedLimit
}

override def getMetaData(): Map[String, String] = {
@@ -105,6 +105,17 @@ case class ParquetScanBuilder(
}
}

override def pushLimit(limit: Int): Boolean = {
val sqlConf = sparkSession.sessionState.conf
// TODO: Support limit push down for the row-based Parquet reader
if (sqlConf.parquetVectorizedReaderEnabled && sqlConf.parquetLimitPushDownEnabled) {
pushedLimit = Some(limit)
true
} else {
false
}
}

override def build(): Scan = {
// the `finalSchema` is either pruned in pushAggregation (if aggregates are
// pushed down), or pruned in readDataSchema() (in regular column pruning). These
@@ -114,6 +125,6 @@ case class ParquetScanBuilder(
}
ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
partitionFilters, dataFilters)
partitionFilters, dataFilters, pushedLimit)
}
}
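A spark-shell sketch of how the gating in pushLimit plays out end to end (the path is illustrative; the conf keys are the same ones the new test suite below uses, and parquet must be removed from the V1 source list so the V2 ParquetScan is planned):

// Route parquet through the V2 code path and enable both gates.
spark.conf.set("spark.sql.sources.useV1SourceList", "")
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
spark.conf.set("spark.sql.parquet.limitPushdown.enabled", "true")

// The ParquetScan description() should now report "PushedLimit: Some(5)".
spark.read.parquet("/tmp/t").limit(5).explain()

// With the vectorized reader disabled, pushLimit() falls back to returning
// false and the description reports "PushedLimit: None".
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
spark.read.parquet("/tmp/t").limit(5).explain()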
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.parquet

import scala.util.Random

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.CollectLimitExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

/**
* A test suite for the Parquet limit pushdown optimization.
*/
class ParquetLimitPushDownSuite extends QueryTest with ParquetTest with SharedSparkSession {
test("[SPARK-37933] test limit pushdown for vectorized parquet reader") {
import testImplicits._
withSQLConf(
SQLConf.PARQUET_LIMIT_PUSHDOWN_ENABLED.key -> "true",
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true",
SQLConf.USE_V1_SOURCE_LIST.key -> "") {
withTempPath { path =>
(1 to 1024).map(i => (101, i)).toDF("a", "b").coalesce(1).write.parquet(path.getPath)
val pushedLimit = Random.nextInt(100) + 1 // avoid the degenerate limit(0) case
val df = spark.read.parquet(path.getPath).limit(pushedLimit)
val sparkPlan = df.queryExecution.sparkPlan
sparkPlan foreachUp {
case r @ BatchScanExec(_, f: ParquetScan, _) =>
assert(f.pushedLimit.contains(pushedLimit))
assert(r.executeColumnar().map(_.numRows()).sum() == pushedLimit)
case CollectLimitExec(limit, _) =>
assert(limit == pushedLimit)
case _ => // other operators, if any, are not relevant to this check
}
assert(df.count() == pushedLimit)
}
}
}
}