Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ object SQLConf {
"false, this configuration does not take any effect.")
.version("3.1.0")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)

val CROSS_JOINS_ENABLED = buildConf("spark.sql.crossJoin.enabled")
.internal()
Expand Down
21 changes: 20 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.{SPARK_VERSION, SparkConf, SparkContext, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.EXECUTOR_ALLOW_SPARK_CONTEXT
import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
Expand Down Expand Up @@ -1077,6 +1077,25 @@ object SparkSession extends Logging {
throw new IllegalStateException("No active or default Spark session found")))
}

/**
 * Returns a cloned SparkSession with all specified configurations disabled, or
 * the original SparkSession if all configurations are already disabled.
 */
private[sql] def getOrCloneSessionWithConfigsOff(
    session: SparkSession,
    configurations: Seq[ConfigEntry[Boolean]]): SparkSession = {
  // Only the configs currently enabled in the session need to be flipped off.
  val enabledConfigs = configurations.filter(session.sessionState.conf.getConf(_))
  if (enabledConfigs.isEmpty) {
    // Nothing to disable, so reuse the original session as-is.
    session
  } else {
    val clonedSession = session.cloneSession()
    enabledConfigs.foreach(clonedSession.sessionState.conf.setConf(_, false))
    clonedSession
  }
}

////////////////////////////////////////////////////////////////////////////////////////
// Private methods from now on
////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.immutable.IndexedSeq
import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.ConfigEntry
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
Expand All @@ -31,6 +32,7 @@ import org.apache.spark.sql.execution.columnar.{DefaultCachedBatchSerializer, In
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, FileTable}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK

Expand All @@ -55,6 +57,17 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
@transient @volatile
private var cachedData = IndexedSeq[CachedData]()

/**
 * Configurations that need to be turned off to avoid regressions for cached queries, so that
 * the outputPartitioning of the underlying cached query plan can be leveraged later.
 * Configurations include:
 * 1. AQE
 * 2. Automatic bucketed table scan
 */
private val forceDisableConfigs: Seq[ConfigEntry[Boolean]] = Seq(
SQLConf.ADAPTIVE_EXECUTION_ENABLED,
SQLConf.AUTO_BUCKETED_SCAN_ENABLED)

/** Clears all cached tables. */
def clearCache(): Unit = this.synchronized {
cachedData.foreach(_.cachedRepresentation.cacheBuilder.clearCache())
Expand All @@ -79,10 +92,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
if (lookupCachedData(planToCache).nonEmpty) {
logWarning("Asked to cache already cached data.")
} else {
// Turn off AQE so that the outputPartitioning of the underlying plan can be leveraged.
val sessionWithAqeOff = getOrCloneSessionWithAqeOff(query.sparkSession)
val inMemoryRelation = sessionWithAqeOff.withActive {
val qe = sessionWithAqeOff.sessionState.executePlan(planToCache)
val sessionWithConfigsOff = SparkSession.getOrCloneSessionWithConfigsOff(
query.sparkSession, forceDisableConfigs)
val inMemoryRelation = sessionWithConfigsOff.withActive {
val qe = sessionWithConfigsOff.sessionState.executePlan(planToCache)
InMemoryRelation(
storageLevel,
qe,
Expand Down Expand Up @@ -188,10 +201,10 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
}
needToRecache.map { cd =>
cd.cachedRepresentation.cacheBuilder.clearCache()
// Turn off AQE so that the outputPartitioning of the underlying plan can be leveraged.
val sessionWithAqeOff = getOrCloneSessionWithAqeOff(spark)
val newCache = sessionWithAqeOff.withActive {
val qe = sessionWithAqeOff.sessionState.executePlan(cd.plan)
val sessionWithConfigsOff = SparkSession.getOrCloneSessionWithConfigsOff(
spark, forceDisableConfigs)
val newCache = sessionWithConfigsOff.withActive {
val qe = sessionWithConfigsOff.sessionState.executePlan(cd.plan)
InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
}
val recomputedPlan = cd.copy(cachedRepresentation = newCache)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.internal.SQLConf

/**
* This class provides utility methods related to tree traversal of an [[AdaptiveSparkPlanExec]]
Expand Down Expand Up @@ -137,18 +135,4 @@ trait AdaptiveSparkPlanHelper {
case a: AdaptiveSparkPlanExec => a.executedPlan
case other => other
}

/**
 * Returns a cloned [[SparkSession]] with adaptive execution disabled, or the original
 * [[SparkSession]] if its adaptive execution is already disabled.
 */
def getOrCloneSessionWithAqeOff[T](session: SparkSession): SparkSession = {
  if (session.sessionState.conf.adaptiveExecutionEnabled) {
    // AQE is on: clone the session and disable it on the clone only.
    val clonedSession = session.cloneSession()
    clonedSession.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, false)
    clonedSession
  } else {
    session
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,20 +262,22 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre
"p1=2/file7_0000" -> 1),
buckets = 3)

// No partition pruning
checkScan(table) { partitions =>
assert(partitions.size == 3)
assert(partitions(0).files.size == 5)
assert(partitions(1).files.size == 0)
assert(partitions(2).files.size == 2)
}
withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
// No partition pruning
checkScan(table) { partitions =>
assert(partitions.size == 3)
assert(partitions(0).files.size == 5)
assert(partitions(1).files.size == 0)
assert(partitions(2).files.size == 2)
}

// With partition pruning
checkScan(table.where("p1=2")) { partitions =>
assert(partitions.size == 3)
assert(partitions(0).files.size == 3)
assert(partitions(1).files.size == 0)
assert(partitions(2).files.size == 1)
// With partition pruning
checkScan(table.where("p1=2")) { partitions =>
assert(partitions.size == 3)
assert(partitions(0).files.size == 3)
assert(partitions(1).files.size == 0)
assert(partitions(2).files.size == 1)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,22 +432,24 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
// join1 is a broadcast join where df2 is broadcasted. Note that output partitioning on the
// streamed side (t1) is HashPartitioning (bucketed files).
val join1 = t1.join(df2, t1("i1") === df2("i2") && t1("j1") === df2("j2"))
val plan1 = join1.queryExecution.executedPlan
assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
assert(broadcastJoins.size == 1)
assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
assert(p.partitionings.size == 4)
// Verify all the combinations of output partitioning.
Seq(Seq(t1("i1"), t1("j1")),
Seq(t1("i1"), df2("j2")),
Seq(df2("i2"), t1("j1")),
Seq(df2("i2"), df2("j2"))).foreach { expected =>
val expectedExpressions = expected.map(_.expr)
assert(p.partitionings.exists {
case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
})
withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
val plan1 = join1.queryExecution.executedPlan
assert(collect(plan1) { case e: ShuffleExchangeExec => e }.isEmpty)
val broadcastJoins = collect(plan1) { case b: BroadcastHashJoinExec => b }
assert(broadcastJoins.size == 1)
assert(broadcastJoins(0).outputPartitioning.isInstanceOf[PartitioningCollection])
val p = broadcastJoins(0).outputPartitioning.asInstanceOf[PartitioningCollection]
assert(p.partitionings.size == 4)
// Verify all the combinations of output partitioning.
Seq(Seq(t1("i1"), t1("j1")),
Seq(t1("i1"), df2("j2")),
Seq(df2("i2"), t1("j1")),
Seq(df2("i2"), df2("j2"))).foreach { expected =>
val expectedExpressions = expected.map(_.expr)
assert(p.partitionings.exists {
case h: HashPartitioning => expressionsEqual(h.expressions, expectedExpressions)
})
}
}

// Join on the column from the broadcasted side (i2, j2) and make sure output partitioning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,22 +81,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
.bucketBy(8, "j", "k")
.saveAsTable("bucketed_table")

val bucketValue = Random.nextInt(maxI)
val table = spark.table("bucketed_table").filter($"i" === bucketValue)
val query = table.queryExecution
val output = query.analyzed.output
val rdd = query.toRdd

assert(rdd.partitions.length == 8)

val attrs = table.select("j", "k").queryExecution.analyzed.output
val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
val getBucketId = UnsafeProjection.create(
HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
output)
rows.map(row => getBucketId(row).getInt(0) -> index)
})
checkBucketId.collect().foreach(r => assert(r._1 == r._2))
withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "false") {
val bucketValue = Random.nextInt(maxI)
val table = spark.table("bucketed_table").filter($"i" === bucketValue)
val query = table.queryExecution
val output = query.analyzed.output
val rdd = query.toRdd

assert(rdd.partitions.length == 8)

val attrs = table.select("j", "k").queryExecution.analyzed.output
val checkBucketId = rdd.mapPartitionsWithIndex((index, rows) => {
val getBucketId = UnsafeProjection.create(
HashPartitioning(attrs, 8).partitionIdExpression :: Nil,
output)
rows.map(row => getBucketId(row).getInt(0) -> index)
})
checkBucketId.collect().foreach(r => assert(r._1 == r._2))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
package org.apache.spark.sql.sources

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
Expand Down Expand Up @@ -218,4 +221,24 @@ abstract class DisableUnnecessaryBucketedScanSuite extends QueryTest with SQLTes
}
}
}

test("SPARK-33075: not disable bucketed table scan for cached query") {
  withTable("t1") {
    withSQLConf(SQLConf.AUTO_BUCKETED_SCAN_ENABLED.key -> "true") {
      df1.write.format("parquet").bucketBy(8, "i").saveAsTable("t1")
      spark.catalog.cacheTable("t1")
      assertCached(spark.table("t1"))

      // Verify cached bucketed table scan not disabled
      val partitioning =
        spark.table("t1").queryExecution.executedPlan.outputPartitioning
      val isBucketedOnI = partitioning match {
        case HashPartitioning(Seq(column: AttributeReference), 8) => column.name == "i"
        case _ => false
      }
      assert(isBucketedOnI)
      // The bucketed output partitioning should let the aggregation avoid a shuffle.
      val aggregateQueryPlan =
        sql("SELECT SUM(i) FROM t1 GROUP BY i").queryExecution.executedPlan
      assert(aggregateQueryPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty)
    }
  }
}
}