diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 35ef24c1c3ba6..3024398399962 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -971,7 +971,7 @@ object SQLConf {
         "false, this configuration does not take any effect.")
       .version("3.1.0")
       .booleanConf
-      .createWithDefault(false)
+      .createWithDefault(true)
 
   val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
     .internal()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index b15d6f981291c..b33557dbfdb27 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -29,7 +29,7 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT
+import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.catalog.Catalog
@@ -1077,6 +1077,25 @@ object SparkSession extends Logging {
       throw new IllegalStateException("No active or default Spark session found")))
   }
 
+  /**
+   * Returns a cloned SparkSession with all specified configurations disabled, or
+   * the original SparkSession if all configurations are already disabled.
+   */
+  private[sql] def getOrCloneSessionWithConfigsOff(
+      session: SparkSession,
+      configurations: Seq[ConfigEntry[Boolean]]): SparkSession = {
+    val configsEnabled = configurations.filter(session.sessionState.conf.getConf(_))
+    if (configsEnabled.isEmpty) {
+      session
+    } else {
+      val newSession = session.cloneSession()
+      configsEnabled.foreach(conf => {
+        newSession.sessionState.conf.setConf(conf, false)
+      })
+      newSession
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////////////////
   // Private methods from now on
   ////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 7201026b11b6b..5f72d6005a8dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -22,6 +22,7 @@ import scala.collection.immutable.IndexedSeq
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.ConfigEntry
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
 import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
@@ -31,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{DefaultCachedBatchSerializer, In
 import org.apache.spark.sql.execution.command.CommandUtils
 import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
 
@@ -55,6 +57,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
   @transient @volatile
   private var cachedData = IndexedSeq[CachedData]()
 
+  /**
+   * Configurations need to be turned off to avoid regressions for cached queries, so that the
+   * outputPartitioning of the underlying cached query plan can be leveraged later.
+   * Configurations include:
+   * 1. AQE
+   * 2. Automatic bucketed table scan
+   */
+  private val forceDisableConfigs: Seq[ConfigEntry[Boolean]] = Seq(
+    SQLConf.ADAPTIVE_EXECUTION_ENABLED,
+    SQLConf.AUTO_BUCKETED_SCAN_ENABLED)
+
   /** Clears all cached tables. */
   def clearCache(): Unit = this.synchronized {
     cachedData.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
@@ -79,10 +92,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     if (lookupCachedData(planToCache).nonEmpty) {
       logWarning("Asked to cache already cached data.")
     } else {
-      // Turn off AQE so that the outputPartitioning of the underlying plan can be leveraged.
-      val sessionWithAqeOff = getOrCloneSessionWithAqeOff(query.sparkSession)
-      val inMemoryRelation = sessionWithAqeOff.withActive {
-        val qe = sessionWithAqeOff.sessionState.executePlan(planToCache)
+      val sessionWithConfigsOff = SparkSession.getOrCloneSessionWithConfigsOff(
+        query.sparkSession, forceDisableConfigs)
+      val inMemoryRelation = sessionWithConfigsOff.withActive {
+        val qe = sessionWithConfigsOff.sessionState.executePlan(planToCache)
         InMemoryRelation(
           storageLevel,
           qe,
@@ -188,10 +201,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
     }
     needToRecache.map { cd =>
       cd.cachedRepresentation.cacheBuilder.clearCache()
-      // Turn off AQE so that the outputPartitioning of the underlying plan can be leveraged.
-      val sessionWithAqeOff = getOrCloneSessionWithAqeOff(spark)
-      val newCache = sessionWithAqeOff.withActive {
-        val qe = sessionWithAqeOff.sessionState.executePlan(cd.plan)
+      val sessionWithConfigsOff = SparkSession.getOrCloneSessionWithConfigsOff(
+        spark, forceDisableConfigs)
+      val newCache = sessionWithConfigsOff.withActive {
+        val qe = sessionWithConfigsOff.sessionState.executePlan(cd.plan)
         InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
       }
       val recomputedPlan = cd.copy(cachedRepresentation = newCache)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
index 8d7a2c95081c4..6ba375910a4eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * This class provides utility methods related to tree traversal of an [[AdaptiveSparkPlanExec]]
@@ -137,18 +135,4 @@ trait AdaptiveSparkPlanHelper {
     case a: AdaptiveSparkPlanExec => a.executedPlan
     case other => other
   }
-
-  /**
-   * Returns a cloned [[SparkSession]] with adaptive execution disabled, or the original
-   * [[SparkSession]] if its adaptive execution is already disabled.
-   */
-  def getOrCloneSessionWithAqeOff[T](session: SparkSession): SparkSession = {
-    if (!session.sessionState.conf.adaptiveExecutionEnabled) {
-      session
-    } else {
-      val newSession = session.cloneSession()
-      newSession.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
-      newSession
-    }
-  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index dfd9ba03f5be0..50f32126e5dec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -262,20 +262,22 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre
           "p1=2/file7_0000" -> 1),
         buckets = 3)
 
-    // No partition pruning
-    checkScan(table) { partitions =>
-      assert(partitions.size == 3)
-      assert(partitions(0).files.size == 5)
-      assert(partitions(1).files.size == 0)
-      assert(partitions(2).files.size == 2)
-    }
+    withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
+      // No partition pruning
+      checkScan(table) { partitions =>
+        assert(partitions.size == 3)
+        assert(partitions(0).files.size == 5)
+        assert(partitions(1).files.size == 0)
+        assert(partitions(2).files.size == 2)
+      }
 
-    // With partition pruning
-    checkScan(table.where("p1=2")) { partitions =>
-      assert(partitions.size == 3)
-      assert(partitions(0).files.size == 3)
-      assert(partitions(1).files.size == 0)
-      assert(partitions(2).files.size == 1)
+      // With partition pruning
+      checkScan(table.where("p1=2")) { partitions =>
+        assert(partitions.size == 3)
+        assert(partitions(0).files.size == 3)
+        assert(partitions(1).files.size == 0)
+        assert(partitions(2).files.size == 1)
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 7ff945f5cbfb4..b6d1baf6e7902 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -432,22 +432,24 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
         // join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the
         // streamed side (t1) is HashPartitioning (bucketed files).
         val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
-        val plan1 = join1.queryExecution.executedPlan
-        assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
-        val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
-        assert(broadcastJoins.size == 1)
-        assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
-        val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
-        assert(p.partitionings.size == 4)
-        // Verify all the combinations of output partitioning.
-        Seq(Seq(t1("i1"), t1("j1")),
-          Seq(t1("i1"), df2("j2")),
-          Seq(df2("i2"), t1("j1")),
-          Seq(df2("i2"), df2("j2"))).foreach { expected =>
-          val expectedExpressions = expected.map(_.expr)
-          assert(p.partitionings.exists {
-            case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
-          })
+        withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
+          val plan1 = join1.queryExecution.executedPlan
+          assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
+          val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
+          assert(broadcastJoins.size == 1)
+          assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
+          val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
+          assert(p.partitionings.size == 4)
+          // Verify all the combinations of output partitioning.
+          Seq(Seq(t1("i1"), t1("j1")),
+            Seq(t1("i1"), df2("j2")),
+            Seq(df2("i2"), t1("j1")),
+            Seq(df2("i2"), df2("j2"))).foreach { expected =>
+            val expectedExpressions = expected.map(_.expr)
+            assert(p.partitionings.exists {
+              case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
+            })
+          }
         }
 
         // Join on the column from the broadcasted side (i2, j2) and make sure output partitioning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index f8276b143c1e6..a188e4d9d6d90 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -81,22 +81,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
         .bucketBy(8, "j", "k")
         .saveAsTable("bucketed_table")
 
-      val bucketValue = Random.nextInt(maxI)
-      val table = spark.table("bucketed_table").filter($"i" === bucketValue)
-      val query = table.queryExecution
-      val output = query.analyzed.output
-      val rdd = query.toRdd
-
-      assert(rdd.partitions.length == 8)
-
-      val attrs = table.select("j", "k").queryExecution.analyzed.output
-      val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
-        val getBucketId = UnsafeProjection.create(
-          HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
-          output)
-        rows.map(row => getBucketId(row).getInt(0) -> index)
-      })
-      checkBucketId.collect().foreach(r => assert(r._1 == r._2))
+      withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
+        val bucketValue = Random.nextInt(maxI)
+        val table = spark.table("bucketed_table").filter($"i" === bucketValue)
+        val query = table.queryExecution
+        val output = query.analyzed.output
+        val rdd = query.toRdd
+
+        assert(rdd.partitions.length == 8)
+
+        val attrs = table.select("j", "k").queryExecution.analyzed.output
+        val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
+          val getBucketId = UnsafeProjection.create(
+            HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
+            output)
+          rows.map(row => getBucketId(row).getInt(0) -> index)
+        })
+        checkBucketId.collect().foreach(r => assert(r._1 == r._2))
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
index 1c258bc0dadb9..70b74aed40eca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DisableUnnecessaryBucketedScanSuite.scala
@@ -18,7 +18,10 @@
 package org.apache.spark.sql.sources
 
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.FileSourceScanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -218,4 +221,24 @@ abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTes
       }
     }
   }
+
+  test("SPARK-33075: not disable bucketed table scan for cached query") {
+    withTable("t1") {
+      withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true") {
+        df1.write.format("parquet").bucketBy(8, "i").saveAsTable("t1")
+        spark.catalog.cacheTable("t1")
+        assertCached(spark.table("t1"))
+
+        // Verify cached bucketed table scan not disabled
+        val partitioning = spark.table("t1").queryExecution.executedPlan
+          .outputPartitioning
+        assert(partitioning match {
+          case HashPartitioning(Seq(column: AttributeReference), 8) if column.name == "i" => true
+          case _ => false
+        })
+        val aggregateQueryPlan = sql("SELECT SUM(i) FROM t1 GROUP BY i").queryExecution.executedPlan
+        assert(aggregateQueryPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
+      }
+    }
+  }
 }
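
Illustrative usage sketch (not part of the patch): a minimal, hedged example of the behavior this change aims for. With spark.sql.sources.bucketing.autoBucketedScan.enabled now defaulting to true, CacheManager plans cached queries with AQE and auto bucketed scan force-disabled, so a cached bucketed table keeps its HashPartitioning and an aggregation on the bucket column should not need an extra shuffle. The table name "t1", the local master, and the sample data are assumptions for illustration only.

import org.apache.spark.sql.SparkSession

object CachedBucketedScanSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("cached-bucketed-scan-sketch")
      .getOrCreate()
    import spark.implicits._

    // Write a small bucketed table, then cache it. The cached plan is built by
    // CacheManager with the configs in forceDisableConfigs turned off, so it keeps
    // HashPartitioning on "i" with 8 partitions as its outputPartitioning.
    Seq((1, "a"), (2, "b"), (3, "c")).toDF("i", "j")
      .write.format("parquet").bucketBy(8, "i").saveAsTable("t1")
    spark.catalog.cacheTable("t1")

    // An aggregation on the bucket column is expected to reuse that partitioning and
    // avoid a ShuffleExchange; explain() lets you inspect the physical plan.
    spark.sql("SELECT SUM(i) FROM t1 GROUP BY i").explain()

    spark.stop()
  }
}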