diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 189740e313207..4b9caa7063d6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -371,6 +371,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val DEFAULT_PARALLELISM = buildConf("spark.sql.default.parallelism")
+    .doc("The session-local default number of partitions, which is widely used in " +
+      "physical plans. If not set, physical plans fall back to " +
+      "`spark.default.parallelism` instead.")
+    .version("3.1.0")
+    .intConf
+    .checkValue(_ > 0, "The value of spark.sql.default.parallelism must be positive")
+    .createOptional
+
   val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions")
     .doc("The default number of partitions to use when shuffling data for joins or aggregations. " +
       "Note: For structured streaming, this configuration cannot be changed between query " +
@@ -2784,6 +2793,8 @@ class SQLConf extends Serializable with Logging {
 
   def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED)
 
+  def defaultParallelism: Option[Int] = getConf(DEFAULT_PARALLELISM)
+
   def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)
 
   def numShufflePartitions: Int = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 60a60377d8a3f..ed06c0deb8d3e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -180,6 +180,15 @@ class SparkSession private(
    */
   @transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf)
 
+  /**
+   * Same as `spark.default.parallelism`, but can be isolated across sessions.
+   *
+   * @since 3.1.0
+   */
+  def defaultParallelism: Int = {
+    sessionState.conf.defaultParallelism.getOrElse(sparkContext.defaultParallelism)
+  }
+
   /**
    * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
    * that listen for execution metrics.
@@ -513,7 +522,7 @@ class SparkSession private(
    * @since 2.0.0
    */
   def range(start: Long, end: Long): Dataset[java.lang.Long] = {
-    range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
+    range(start, end, step = 1, numPartitions = defaultParallelism)
   }
 
   /**
@@ -523,7 +532,7 @@ class SparkSession private(
    * @since 2.0.0
    */
   def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
-    range(start, end, step, numPartitions = sparkContext.defaultParallelism)
+    range(start, end, step, numPartitions = defaultParallelism)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index b452213cd6cc7..ff3ed628a84e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -49,7 +49,7 @@ case class LocalTableScanExec(
     if (rows.isEmpty) {
       sqlContext.sparkContext.emptyRDD
     } else {
-      val numSlices = math.min(unsafeRows.length, sqlContext.sparkContext.defaultParallelism)
+      val numSlices = math.min(unsafeRows.length, sqlContext.sparkSession.defaultParallelism)
       sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 84c65df31a7c5..e3a608522cc15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -65,7 +65,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPl
       // We fall back to Spark default parallelism if the minimum number of coalesced partitions
       // is not set, so to avoid perf regressions compared to no coalescing.
       val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
-        .getOrElse(session.sparkContext.defaultParallelism)
+        .getOrElse(session.defaultParallelism)
       val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
         validMetrics.toArray,
         advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 4b376b94566b8..b9ad6480f0ca8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -369,7 +369,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
   val start: Long = range.start
   val end: Long = range.end
   val step: Long = range.step
-  val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
+  val numSlices: Int = range.numSlices.getOrElse(sqlContext.sparkSession.defaultParallelism)
   val numElements: BigInt = range.numElements
 
   override val output: Seq[Attribute] = range.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 47b213fc2d83b..f62a7f5bac3d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -737,7 +737,7 @@ case class AlterTableRecoverPartitionsCommand(
     // Set the number of parallelism to prevent following file listing from generating many tasks
     // in case of large #defaultParallelism.
     val numParallelism = Math.min(serializedPaths.length,
-      Math.min(spark.sparkContext.defaultParallelism, 10000))
+      Math.min(spark.defaultParallelism, 10000))
     // gather the fast stats for all the partitions otherwise Hive metastore will list all the
     // files for all the new partitions in sequential way, which is super slow.
logInfo(s"Gather the fast stats in parallel using $numParallelism tasks.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala index b4fc94e097aa8..c0152249c9b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FilePartition.scala @@ -88,7 +88,7 @@ object FilePartition extends Logging { selectedPartitions: Seq[PartitionDirectory]): Long = { val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes - val defaultParallelism = sparkSession.sparkContext.defaultParallelism + val defaultParallelism = sparkSession.defaultParallelism val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum val bytesPerCore = totalBytes / defaultParallelism diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala index 99882b0f7c7b0..6115b4aaaacf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaMergeUtils.scala @@ -55,7 +55,7 @@ object SchemaMergeUtils extends Logging { // Set the number of partitions to prevent following schema reads from generating many tasks // in case of a small number of orc files. val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1), - sparkSession.sparkContext.defaultParallelism) + sparkSession.defaultParallelism) val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index a093bf54b2107..beb19d496da7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -65,7 +65,7 @@ class RateStreamProvider extends SimpleTableProvider with DataSourceRegister { } val numPartitions = options.getInt( - NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism) + NUM_PARTITIONS, SparkSession.active.defaultParallelism) if (numPartitions <= 0) { throw new IllegalArgumentException( s"Invalid value '$numPartitions'. 
The option 'numPartitions' must be positive") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index a4dcb2049eb87..a4820c4babd79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -60,7 +60,7 @@ class TextSocketSourceProvider extends SimpleTableProvider with DataSourceRegist new TextSocketTable( options.get("host"), options.getInt("port", -1), - options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), + options.getInt("numPartitions", SparkSession.active.defaultParallelism), options.getBoolean("includeTimestamp", false)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 003f5bc835d5f..03e11499b7c22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.QueryExecutionListener class SessionStateSuite extends SparkFunSuite { @@ -239,4 +240,15 @@ class SessionStateSuite extends SparkFunSuite { activeSession.conf.unset(key) } } + + test("add spark.sql.default.parallelism in SQLConf") { + val key = SQLConf.DEFAULT_PARALLELISM.key + try { + assert(activeSession.defaultParallelism == activeSession.sparkContext.defaultParallelism) + activeSession.conf.set(key, "1") + assert(activeSession.defaultParallelism == 1) + } finally { + activeSession.conf.unset(key) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bbd0220a74f88..1dc7f1f78d771 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -1307,7 +1307,7 @@ class FakeDefaultSource extends FakeSource { startOffset, end.asInstanceOf[LongOffset].offset + 1, 1, - Some(spark.sparkSession.sparkContext.defaultParallelism), + Some(spark.sparkSession.defaultParallelism), isStreaming = true), Encoders.LONG) ds.toDF("a")
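
For context, a minimal usage sketch of the session-scoped setting added above, assuming the patch is applied. The local[4] master, the app name, and the partition counts in the comments are illustrative assumptions, not part of the patch.

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .master("local[4]")                    // illustrative: sparkContext.defaultParallelism == 4
    .appName("default-parallelism-demo")
    .getOrCreate()

  // Without the new conf, range() still falls back to sparkContext.defaultParallelism.
  println(spark.range(0, 100).rdd.getNumPartitions)   // 4

  // Setting spark.sql.default.parallelism overrides that fallback for this session only.
  spark.conf.set("spark.sql.default.parallelism", "2")
  println(spark.range(0, 100).rdd.getNumPartitions)   // 2

  // A separate session is unaffected, since the conf is isolated per session.
  val other = spark.newSession()
  println(other.range(0, 100).rdd.getNumPartitions)   // still 4

  spark.conf.unset("spark.sql.default.parallelism")
  spark.stop()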