diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 296076d67eb39..ffc96c56bbe73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -142,9 +142,19 @@ object SQLExecution {
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
     Future {
+      val originalSession = SparkSession.getActiveSession
+      val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
       sc.setLocalProperties(localProps)
-      body
+      val res = body
+      // reset active session and local props.
+      sc.setLocalProperties(originalLocalProps)
+      if (originalSession.nonEmpty) {
+        SparkSession.setActiveSession(originalSession.get)
+      } else {
+        SparkSession.clearActiveSession()
+      }
+      res
     }(exec)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 2233f4d2a1c04..7059ce8faf927 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.internal
 
+import java.util.UUID
+
+import org.scalatest.Assertions._
+
 import org.apache.spark.{SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, SparkSession}
@@ -148,14 +152,14 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
         }
 
         // set local configuration and assert
-        val confValue1 = "e"
+        val confValue1 = UUID.randomUUID().toString()
         createDataframe(confKey, confValue1).createOrReplaceTempView("m")
         spark.sparkContext.setLocalProperty(confKey, confValue1)
         val result1 = sql("SELECT value, (SELECT MAX(*) FROM m) x FROM l").collect
         assert(result1.forall(_.getBoolean(1)))
 
         // change the conf value and assert again
-        val confValue2 = "f"
+        val confValue2 = UUID.randomUUID().toString()
         createDataframe(confKey, confValue2).createOrReplaceTempView("n")
         spark.sparkContext.setLocalProperty(confKey, confValue2)
         val result2 = sql("SELECT value, (SELECT MAX(*) FROM n) x FROM l").collect
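
A note on the restore path in the SQLExecution hunk: `body` is evaluated before the originals are put back, so if `body` throws, the reset is skipped and the pooled thread keeps the captured session and local properties. The sketch below shows the same save/restore pattern written with try/finally. It is an illustration, not part of this patch: the method signature and the wrapping object name are reconstructed from the hunk's context, and the code assumes it lives inside Spark's own source tree, where the private[spark] SparkContext and Utils APIs used here are accessible.

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils

// Sketch only: mirrors the patched withThreadLocalCaptured, but restores the
// original thread-locals in a finally block so they come back even when `body`
// throws. The object name is hypothetical; the real method sits in SQLExecution.
object SQLExecutionSketch {
  def withThreadLocalCaptured[T](
      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
    val activeSession = sparkSession
    val sc = sparkSession.sparkContext
    val localProps = Utils.cloneProperties(sc.getLocalProperties)
    Future {
      // Save whatever the pooled thread currently has...
      val originalSession = SparkSession.getActiveSession
      val originalLocalProps = sc.getLocalProperties
      // ...install the captured state for the duration of `body`...
      SparkSession.setActiveSession(activeSession)
      sc.setLocalProperties(localProps)
      try {
        body
      } finally {
        // ...and always put the originals back, on success or failure.
        sc.setLocalProperties(originalLocalProps)
        originalSession match {
          case Some(session) => SparkSession.setActiveSession(session)
          case None => SparkSession.clearActiveSession()
        }
      }
    }(exec)
  }
}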