diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 02fb73ed30680..38e63d425bb21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -542,10 +542,9 @@ case class FileSourceScanExec(
     }.groupBy { f =>
       BucketingUtils
         .getBucketId(new Path(f.filePath).getName)
-        .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
+        .getOrElse(throw new IllegalStateException(s"Invalid bucket file ${f.filePath}"))
     }
 
-    // TODO(SPARK-32985): Decouple bucket filter pruning and bucketed table scan
     val prunedFilesGroupedToBuckets = if (optionalBucketSet.isDefined) {
       val bucketSet = optionalBucketSet.get
       filesGroupedToBuckets.filter {
@@ -591,20 +590,41 @@ case class FileSourceScanExec(
     logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
       s"open cost is considered as scanning $openCostInBytes bytes.")
 
+    // Filter files with bucket pruning if possible
+    val bucketingEnabled = fsRelation.sparkSession.sessionState.conf.bucketingEnabled
+    val shouldProcess: Path => Boolean = optionalBucketSet match {
+      case Some(bucketSet) if bucketingEnabled =>
+        filePath => {
+          BucketingUtils.getBucketId(filePath.getName) match {
+            case Some(id) => bucketSet.get(id)
+            case None =>
+              // Do not prune the file if bucket file name is invalid
+              true
+          }
+        }
+      case _ =>
+        _ => true
+    }
+
     val splitFiles = selectedPartitions.flatMap { partition =>
       partition.files.flatMap { file =>
         // getPath() is very expensive so we only want to call it once in this block:
         val filePath = file.getPath
-        val isSplitable = relation.fileFormat.isSplitable(
-          relation.sparkSession, relation.options, filePath)
-        PartitionedFileUtil.splitFiles(
-          sparkSession = relation.sparkSession,
-          file = file,
-          filePath = filePath,
-          isSplitable = isSplitable,
-          maxSplitBytes = maxSplitBytes,
-          partitionValues = partition.values
-        )
+
+        if (shouldProcess(filePath)) {
+          val isSplitable = relation.fileFormat.isSplitable(
+            relation.sparkSession, relation.options, filePath)
+          PartitionedFileUtil.splitFiles(
+            sparkSession = relation.sparkSession,
+            file = file,
+            filePath = filePath,
+            isSplitable = isSplitable,
+            maxSplitBytes = maxSplitBytes,
+            partitionValues = partition.values
+          )
+        } else {
+          Seq.empty
+        }
       }
     }.sortBy(_.length)(implicitly[Ordering[Long]].reverse)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
index 6b195b3b49f09..98bcab2a839af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/bucketing/DisableUnnecessaryBucketedScan.scala
@@ -98,7 +98,7 @@ object DisableUnnecessaryBucketedScan extends Rule[SparkPlan] {
         exchange.mapChildren(disableBucketWithInterestingPartition(
           _, withInterestingPartition, true, withAllowedNode))
       case scan: FileSourceScanExec =>
-        if (isBucketedScanWithoutFilter(scan)) {
+        if (scan.bucketedScan) {
           if (!withInterestingPartition || (withExchange && withAllowedNode)) {
             val nonBucketedScan = scan.copy(disableBucketedScan = true)
             scan.logicalLink.foreach(nonBucketedScan.setLogicalLink)
@@ -140,20 +140,13 @@ object DisableUnnecessaryBucketedScan extends Rule[SparkPlan] {
     }
   }
 
-  private def isBucketedScanWithoutFilter(scan: FileSourceScanExec): Boolean = {
-    // Do not disable bucketed table scan if it has filter pruning,
-    // because bucketed table scan is still useful here to save CPU/IO cost with
-    // only reading selected bucket files.
-    scan.bucketedScan && scan.optionalBucketSet.isEmpty
-  }
-
   def apply(plan: SparkPlan): SparkPlan = {
-    lazy val hasBucketedScanWithoutFilter = plan.find {
-      case scan: FileSourceScanExec => isBucketedScanWithoutFilter(scan)
+    lazy val hasBucketedScan = plan.find {
+      case scan: FileSourceScanExec => scan.bucketedScan
       case _ => false
     }.isDefined
 
-    if (!conf.bucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScanWithoutFilter) {
+    if (!conf.bucketingEnabled || !conf.autoBucketedScanEnabled || !hasBucketedScan) {
       plan
     } else {
       disableBucketWithInterestingPartition(plan, false, false, true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9dcc0cfda93f1..d0f569cf675f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -117,9 +117,12 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
       bucketValues: Seq[Any],
       filterCondition: Column,
       originalDataFrame: DataFrame): Unit = {
-    // This test verifies parts of the plan. Disable whole stage codegen.
-    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
-      val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k")
+    // This test verifies parts of the plan. Disable whole stage codegen,
+    // auto bucketed scan, and filter push-down for the json data source.
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
+      SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false",
+      SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") {
+      val bucketedDataFrame = spark.table("bucketed_table")
       val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec
       // Limit: bucket pruning only works when the bucket column has one and only one column
       assert(bucketColumnNames.length == 1)
@@ -148,11 +151,41 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
       if (invalidBuckets.nonEmpty) {
         fail(s"Buckets ${invalidBuckets.mkString(",")} should have been pruned from:\n$plan")
       }
+
+      withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true") {
+        // Bucket pruning should still work without bucketed scan
+        val planWithoutBucketedScan = bucketedDataFrame.filter(filterCondition)
+          .queryExecution.executedPlan
+        val fileScan = getFileScan(planWithoutBucketedScan)
+        assert(!fileScan.bucketedScan, s"expected no bucketed scan but found\n$fileScan")
+
+        val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
+        val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
+          // Return rows that should have been pruned
+          val bucketColumnValue = row.get(bucketColumnIndex, bucketColumnType)
+          val bucketId = BucketingUtils.getBucketIdFromValue(
+            bucketColumn, numBuckets, bucketColumnValue)
+          !matchedBuckets.get(bucketId)
+        }).collect()
+
+        if (rowsWithInvalidBuckets.nonEmpty) {
+          fail(s"Rows ${rowsWithInvalidBuckets.mkString(",")} should have been pruned from:\n" +
+            s"$planWithoutBucketedScan")
+        }
+      }
     }
 
+    val expectedDataFrame = originalDataFrame.filter(filterCondition).orderBy("i", "j", "k")
+      .select("i", "j", "k")
     checkAnswer(
-      bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k"),
-      originalDataFrame.filter(filterCondition).orderBy("i", "j", "k"))
+      bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k").select("i", "j", "k"),
+      expectedDataFrame)
+
+    withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true") {
+      checkAnswer(
+        bucketedDataFrame.filter(filterCondition).orderBy("i", "j", "k").select("i", "j", "k"),
+        expectedDataFrame)
+    }
   }
 
@@ -160,7 +193,6 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
     withTable("bucketed_table") {
       val numBuckets = NumBucketsForPruningDF
       val bucketSpec = BucketSpec(numBuckets, Seq("j"), Nil)
-      // json does not support predicate push-down, and thus json is used here
       df.write
         .format("json")
         .partitionBy("i")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index 179cdeb976391..1a19824a31555 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -95,7 +95,7 @@ abstract class DisableUnnecessaryBucketedScanSuite
       ("SELECT i FROM t1", 0, 1),
       ("SELECT j FROM t1", 0, 0),
       // Filter on bucketed column
-      ("SELECT * FROM t1 WHERE i = 1", 1, 1),
+      ("SELECT * FROM t1 WHERE i = 1", 0, 1),
       // Filter on non-bucketed column
       ("SELECT * FROM t1 WHERE j = 1", 0, 1),
       // Join with same buckets
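
Reviewer note: the heart of this patch is the `shouldProcess` predicate added to `DataSourceScanExec` above, which prunes bucket files at file-listing time even when the bucketed scan itself is disabled. Below is a minimal, self-contained Scala sketch of that predicate's semantics. Everything here is illustrative: `BucketPruningSketch`, the regex, `scala.collection.immutable.BitSet`, and the sample file names stand in for Spark's private `BucketingUtils.getBucketId` and internal `BitSet`; this is not Spark's API.

import scala.collection.immutable.BitSet

object BucketPruningSketch {
  // Assumption: bucketed file names carry the bucket id as a "_NNNNN" suffix
  // before the extension, mirroring Spark's bucket file naming scheme.
  private val bucketIdPattern = """.*_(\d+)(?:\..*)?$""".r

  // Parse the bucket id out of a file name, if present.
  def getBucketId(fileName: String): Option[Int] = fileName match {
    case bucketIdPattern(id) => Some(id.toInt)
    case _ => None
  }

  // Mirrors `shouldProcess`: with a bucket set, keep only files whose bucket id
  // survives pruning; a file with no parsable bucket id is never pruned.
  def shouldProcess(keptBuckets: Option[BitSet])(fileName: String): Boolean =
    keptBuckets match {
      case Some(buckets) => getBucketId(fileName).forall(buckets.contains)
      case None => true
    }

  def main(args: Array[String]): Unit = {
    val keep: String => Boolean = shouldProcess(Some(BitSet(0, 2)))
    println(keep("part-00000-1234_00000.snappy.parquet")) // true: bucket 0 is kept
    println(keep("part-00001-1234_00001.snappy.parquet")) // false: bucket 1 is pruned
    println(keep("not-a-bucketed-file.parquet"))          // true: no bucket id, not pruned
  }
}

As in the patch, pruning is conservative: a file whose name carries no parsable bucket id falls through to `true` and is still read, so an unexpected naming convention can only cost I/O, never correctness.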