@@ -371,6 +371,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val DEFAULT_PARALLELISM = buildConf("spark.sql.default.parallelism")
Member: spark.sql.default.parallelism -> spark.sql.sessionLocalDefaultParallelism?

Contributor Author: Hmm, is it better to keep it similar to spark.default.parallelism, so this config is easy to set? sessionLocalDefaultParallelism seems complex.

.doc("The session-local default number of partitions and this value is widely used " +
"inside physical plans. If not set, the physical plans refer to " +
"`spark.default.parallelism` instead.")
.version("3.1.0")
.intConf
.checkValue(_ > 0, "The value of spark.sql.default.parallelism must be positive")
.createOptional

val SHUFFLE_PARTITIONS = buildConf("spark.sql.shuffle.partitions")
.doc("The default number of partitions to use when shuffling data for joins or aggregations. " +
"Note: For structured streaming, this configuration cannot be changed between query " +
@@ -2784,6 +2793,8 @@ class SQLConf extends Serializable with Logging {

def cacheVectorizedReaderEnabled: Boolean = getConf(CACHE_VECTORIZED_READER_ENABLED)

def defaultParallelism: Option[Int] = getConf(DEFAULT_PARALLELISM)

def defaultNumShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS)

def numShufflePartitions: Int = {
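To make the new config's intended behavior concrete, here is a minimal usage sketch (illustrative only; it assumes a local[4] master and that the config and the SparkSession.defaultParallelism API land as shown in this diff):

// Session-local parallelism: set per session, siblings keep the global default.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[4]").appName("demo").getOrCreate()

// Without the session config, plans fall back to sparkContext.defaultParallelism (4 here).
println(spark.range(0, 100).rdd.getNumPartitions)   // 4

// Override it for this session only.
spark.conf.set("spark.sql.default.parallelism", "2")
println(spark.range(0, 100).rdd.getNumPartitions)   // 2

// A sibling session sharing the same SparkContext is unaffected.
val other = spark.newSession()
println(other.range(0, 100).rdd.getNumPartitions)   // 4

spark.stop()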
13 changes: 11 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -180,6 +180,15 @@ class SparkSession private(
*/
@transient lazy val conf: RuntimeConfig = new RuntimeConfig(sessionState.conf)

/**
* Same as `spark.default.parallelism`, can be isolated across sessions.
*
* @since 3.1.0
*/
def defaultParallelism: Int = {
Contributor: I'd like to not have this API, as SparkSession should provide high-level logical APIs, not physical ones.

sessionState.conf.defaultParallelism.getOrElse(sparkContext.defaultParallelism)
Contributor: So we add a config whose only usage is to let users get the config value?

Contributor Author: As I said above, if we add this config I will move the existing defaultParallelism usages in the sql module in a follow-up, e.g. FilePartition.maxSplitBytes().

Contributor Author: Just do this in this PR?

Contributor: Please do, otherwise it's a useless config.

}

/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
@@ -513,7 +522,7 @@ class SparkSession private(
* @since 2.0.0
*/
def range(start: Long, end: Long): Dataset[java.lang.Long] = {
range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
range(start, end, step = 1, numPartitions = defaultParallelism)
}

/**
@@ -523,7 +532,7 @@
* @since 2.0.0
*/
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
range(start, end, step, numPartitions = sparkContext.defaultParallelism)
range(start, end, step, numPartitions = defaultParallelism)
}

/**
@@ -49,7 +49,7 @@ case class LocalTableScanExec(
if (rows.isEmpty) {
sqlContext.sparkContext.emptyRDD
} else {
val numSlices = math.min(unsafeRows.length, sqlContext.sparkContext.defaultParallelism)
val numSlices = math.min(unsafeRows.length, sqlContext.sparkSession.defaultParallelism)
sqlContext.sparkContext.parallelize(unsafeRows, numSlices)
}
}
@@ -65,7 +65,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends Rule[SparkPlan] {
// We fall back to Spark default parallelism if the minimum number of coalesced partitions
// is not set, so to avoid perf regressions compared to no coalescing.
val minPartitionNum = conf.getConf(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM)
.getOrElse(session.sparkContext.defaultParallelism)
.getOrElse(session.defaultParallelism)
val partitionSpecs = ShufflePartitionsUtil.coalescePartitions(
validMetrics.toArray,
advisoryTargetSize = conf.getConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES),
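A side note on the CoalesceShufflePartitions change above: when the minimum number of coalesced partitions is left unset, the coalescing floor now follows the session's defaultParallelism rather than the SparkContext-wide value. A configuration sketch, assuming a SparkSession named spark (the key names other than the new one are existing Spark 3.0 configs):

// AQE partition coalescing with a session-local floor (illustrative).
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
// No explicit minPartitionNum, so the floor comes from the session's default parallelism:
spark.conf.set("spark.sql.default.parallelism", "6")
// With an explicit minimum, the session-local parallelism is ignored for coalescing:
// spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionNum", "8")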
@@ -369,7 +369,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
val start: Long = range.start
val end: Long = range.end
val step: Long = range.step
val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
val numSlices: Int = range.numSlices.getOrElse(sqlContext.sparkSession.defaultParallelism)
val numElements: BigInt = range.numElements

override val output: Seq[Attribute] = range.output
@@ -737,7 +737,7 @@ case class AlterTableRecoverPartitionsCommand(
// Set the number of parallelism to prevent following file listing from generating many tasks
// in case of large #defaultParallelism.
val numParallelism = Math.min(serializedPaths.length,
Math.min(spark.sparkContext.defaultParallelism, 10000))
Math.min(spark.defaultParallelism, 10000))
// gather the fast stats for all the partitions otherwise Hive metastore will list all the
// files for all the new partitions in sequential way, which is super slow.
logInfo(s"Gather the fast stats in parallel using $numParallelism tasks.")
@@ -88,7 +88,7 @@ object FilePartition extends Logging {
selectedPartitions: Seq[PartitionDirectory]): Long = {
val defaultMaxSplitBytes = sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = sparkSession.sessionState.conf.filesOpenCostInBytes
val defaultParallelism = sparkSession.sparkContext.defaultParallelism
val defaultParallelism = sparkSession.defaultParallelism
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum
val bytesPerCore = totalBytes / defaultParallelism

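Why the session-local parallelism matters in FilePartition.maxSplitBytes: bytesPerCore divides the total input bytes by the default parallelism, and the split size is then capped by maxPartitionBytes and floored by openCostInBytes (the min/max capping mirrors the rest of this method, outside the hunk), so a larger per-session parallelism yields smaller splits and more read tasks. A rough worked sketch with hypothetical numbers:

// Hypothetical input: ten ~100 MB files, default file-source config values.
val defaultMaxSplitBytes = 128L * 1024 * 1024            // spark.sql.files.maxPartitionBytes default
val openCostInBytes      = 4L * 1024 * 1024              // spark.sql.files.openCostInBytes default
val totalBytes           = 10L * (100L * 1024 * 1024 + openCostInBytes)

def splitBytes(defaultParallelism: Int): Long = {
  val bytesPerCore = totalBytes / defaultParallelism
  math.min(defaultMaxSplitBytes, math.max(openCostInBytes, bytesPerCore))
}

println(splitBytes(8))    // ~128 MB: bytesPerCore (~130 MB) is capped by maxPartitionBytes
println(splitBytes(32))   // ~32.5 MB: a higher session-local parallelism produces smaller splits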
@@ -55,7 +55,7 @@ object SchemaMergeUtils extends Logging {
// Set the number of partitions to prevent following schema reads from generating many tasks
// in case of a small number of orc files.
val numParallelism = Math.min(Math.max(partialFileStatusInfo.size, 1),
sparkSession.sparkContext.defaultParallelism)
sparkSession.defaultParallelism)

val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles

@@ -65,7 +65,7 @@ class RateStreamProvider extends SimpleTableProvider with DataSourceRegister {
}

val numPartitions = options.getInt(
NUM_PARTITIONS, SparkSession.active.sparkContext.defaultParallelism)
NUM_PARTITIONS, SparkSession.active.defaultParallelism)
if (numPartitions <= 0) {
throw new IllegalArgumentException(
s"Invalid value '$numPartitions'. The option 'numPartitions' must be positive")
@@ -60,7 +60,7 @@ class TextSocketSourceProvider extends SimpleTableProvider with DataSourceRegister {
new TextSocketTable(
options.get("host"),
options.getInt("port", -1),
options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism),
options.getInt("numPartitions", SparkSession.active.defaultParallelism),
options.getBoolean("includeTimestamp", false))
}

@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.QueryExecutionListener

class SessionStateSuite extends SparkFunSuite {
@@ -239,4 +240,15 @@ class SessionStateSuite extends SparkFunSuite {
activeSession.conf.unset(key)
}
}

test("add spark.sql.default.parallelism in SQLConf") {
val key = SQLConf.DEFAULT_PARALLELISM.key
try {
assert(activeSession.defaultParallelism == activeSession.sparkContext.defaultParallelism)
activeSession.conf.set(key, "1")
assert(activeSession.defaultParallelism == 1)
} finally {
activeSession.conf.unset(key)
}
}
}
@@ -1307,7 +1307,7 @@ class FakeDefaultSource extends FakeSource {
startOffset,
end.asInstanceOf[LongOffset].offset + 1,
1,
Some(spark.sparkSession.sparkContext.defaultParallelism),
Some(spark.sparkSession.defaultParallelism),
isStreaming = true),
Encoders.LONG)
ds.toDF("a")