diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 67803ad76d5e5..ae010f984f791 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -323,26 +323,28 @@ abstract class StreamExecution(
       startLatch.countDown()
 
       // While active, repeatedly attempt to run batches.
-      SparkSession.setActiveSession(sparkSession)
-
-      updateStatusMessage("Initializing sources")
-      // force initialization of the logical plan so that the sources can be created
-      logicalPlan
-
-      // Adaptive execution can change num shuffle partitions, disallow
-      sparkSessionForStream.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
-      // Disable cost-based join optimization as we do not want stateful operations to be rearranged
-      sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
-      offsetSeqMetadata = OffsetSeqMetadata(
-        batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf)
-
-      if (state.compareAndSet(INITIALIZING, ACTIVE)) {
-        // Unblock `awaitInitialization`
-        initializationLatch.countDown()
-        runActivatedStream(sparkSessionForStream)
-        updateStatusMessage("Stopped")
-      } else {
-        // `stop()` is already called. Let `finally` finish the cleanup.
+      sparkSessionForStream.withActive {
+        // Adaptive execution can change num shuffle partitions, disallow
+        sparkSessionForStream.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
+        // Disable cost-based join optimization as we do not want stateful operations
+        // to be rearranged
+        sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
+
+        updateStatusMessage("Initializing sources")
+        // force initialization of the logical plan so that the sources can be created
+        logicalPlan
+
+        offsetSeqMetadata = OffsetSeqMetadata(
+          batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf)
+
+        if (state.compareAndSet(INITIALIZING, ACTIVE)) {
+          // Unblock `awaitInitialization`
+          initializationLatch.countDown()
+          runActivatedStream(sparkSessionForStream)
+          updateStatusMessage("Stopped")
+        } else {
+          // `stop()` is already called. Let `finally` finish the cleanup.
+        }
       }
     } catch {
       case e if isInterruptedByStop(e, sparkSession.sparkContext) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c4e43d24b0b82..fb6922aca19a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
@@ -1266,6 +1266,37 @@ class StreamSuite extends StreamTest {
       }
     }
   }
+
+  test("SPARK-34482: correct active SparkSession for logicalPlan") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") {
+      val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load()
+      var query: StreamExecution = null
+      try {
+        query =
+          df.repartition($"a")
+            .writeStream
+            .format("memory")
+            .queryName("memory")
+            .start()
+            .asInstanceOf[StreamingQueryWrapper]
+            .streamingQuery
+        query.awaitInitialization(streamingTimeout.toMillis)
+        val plan = query.logicalPlan
+        val numPartition = plan
+          .find { _.isInstanceOf[RepartitionByExpression] }
+          .map(_.asInstanceOf[RepartitionByExpression].numPartitions)
+        // Before the fix of SPARK-34482, the numPartition is the value of
+        // `COALESCE_PARTITIONS_INITIAL_PARTITION_NUM`.
+        assert(numPartition.get === spark.sqlContext.conf.getConf(SQLConf.SHUFFLE_PARTITIONS))
+      } finally {
+        if (query != null) {
+          query.stop()
+        }
+      }
+    }
+  }
 }
 
 abstract class FakeSource extends StreamSourceProvider {
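
For context on why the fix works: `withActive` makes the given session the thread-local active session for the duration of the block and restores the previous one afterwards. Because the AQE/CBO conf overrides now happen inside that block and before `logicalPlan` is forced, plan analysis resolves configuration (such as the shuffle partition number) against `sparkSessionForStream` rather than whichever session happened to be active. Below is a minimal sketch of that set-then-restore pattern, using a hypothetical stand-in `Session` type rather than Spark's actual (package-private) implementation:

```scala
// Minimal sketch (not Spark code) of the set-then-restore pattern
// behind `withActive`. `Session` is a hypothetical stand-in type.
object WithActiveSketch {
  final class Session(val name: String)

  // Thread-local slot for the "active" session, analogous to
  // SparkSession's active-session thread local.
  private val active = new ThreadLocal[Session]

  def withActive[T](session: Session)(block: => T): T = {
    val previous = active.get()   // remember what was active before
    active.set(session)           // activate `session` for the block
    try block
    finally active.set(previous)  // restore even if `block` throws
  }

  def main(args: Array[String]): Unit = {
    val streamSession = new Session("sparkSessionForStream")
    withActive(streamSession) {
      // Work that consults the active session (e.g. forcing `logicalPlan`)
      // sees `streamSession` and any conf values just set on it.
      assert(active.get() eq streamSession)
    }
    assert(active.get() == null)  // prior state (none) is restored
  }
}
```

The `finally`-based restore is what makes the wrapper safe to use around `runActivatedStream`: the caller's active session is put back even when the stream terminates with an exception.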