diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index de39e4b410f25..e7872bb9cb6b0 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.util.concurrent._
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.locks.ReentrantLock
 
 import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
@@ -304,6 +305,22 @@ private[spark] object ThreadUtils {
   }
   // scalastyle:on awaitresult
 
+  @throws(classOf[SparkException])
+  def awaitResult[T](future: JFuture[T], atMost: Duration): T = {
+    try {
+      atMost match {
+        case Duration.Inf => future.get()
+        case _ => future.get(atMost._1, atMost._2)
+      }
+    } catch {
+      case e: SparkFatalException =>
+        throw e.throwable
+      case NonFatal(t)
+        if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
+        throw new SparkException("Exception thrown in awaitResult: ", t)
+    }
+  }
+
   // scalastyle:off awaitready
   /**
    * Preferred alternative to `Await.ready()`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 59c503e372535..5e4f30a5edaf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -17,11 +17,9 @@
 package org.apache.spark.sql.execution
 
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
 import java.util.concurrent.atomic.AtomicLong
 
-import scala.concurrent.{ExecutionContext, Future}
-
 import org.apache.spark.SparkContext
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.SparkSession
@@ -172,11 +170,11 @@ object SQLExecution {
    * SparkContext local properties are forwarded to execution thread
    */
   def withThreadLocalCaptured[T](
-      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
+      sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
     val activeSession = sparkSession
     val sc = sparkSession.sparkContext
     val localProps = Utils.cloneProperties(sc.getLocalProperties)
-    Future {
+    exec.submit(() => {
       val originalSession = SparkSession.getActiveSession
       val originalLocalProps = sc.getLocalProperties
       SparkSession.setActiveSession(activeSession)
@@ -190,6 +188,6 @@
         SparkSession.clearActiveSession()
       }
       res
-    }(exec)
+    })
   }
 }
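Note on the two new entry points above: `SQLExecution.withThreadLocalCaptured` clones the caller's `SparkContext` local properties and active session, re-installs them on the pool thread before the body runs, and hands back a `java.util.concurrent.Future`; the new `ThreadUtils.awaitResult` overload is the blocking counterpart for that Java future. The sketch below only illustrates the intended call shape and is not part of the patch: the package name, object name, thread pool, and the `spark.sql.x` key are made up, and it is placed under `org.apache.spark` purely so the `private[spark]` `ThreadUtils` helper is visible (for example, from a test module).

```scala
// Illustrative sketch only; names below are invented for the example.
package org.apache.spark.sketch

import java.util.concurrent.Executors

import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.util.ThreadUtils

object CapturedPropertiesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    val pool = Executors.newFixedThreadPool(1)

    // Local properties visible on the calling thread are cloned at submission time...
    spark.sparkContext.setLocalProperty("spark.sql.x", "some-value")

    // ...and re-installed on the pool thread before the body executes.
    val jFuture = SQLExecution.withThreadLocalCaptured(spark, pool) {
      spark.sparkContext.getLocalProperty("spark.sql.x")
    }

    // The returned java.util.concurrent.Future is awaited through the new overload,
    // which wraps non-fatal failures in SparkException like the Scala-future variant.
    assert(ThreadUtils.awaitResult(jFuture, Duration.Inf) == "some-value")

    pool.shutdown()
    spark.stop()
  }
}
```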
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index c35c48496e1c9..4e65bbf75282f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,10 +17,11 @@
 package org.apache.spark.sql.execution
 
+import java.util.concurrent.{Future => JFuture}
 import java.util.concurrent.TimeUnit._
 
 import scala.collection.mutable
-import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext}
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -746,7 +747,7 @@ case class SubqueryExec(name: String, child: SparkPlan)
     "collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))
 
   @transient
-  private lazy val relationFuture: Future[Array[InternalRow]] = {
+  private lazy val relationFuture: JFuture[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
     val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
     SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
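The reason `SubqueryExec.relationFuture` can switch to a Java future with such a small diff is that call sites which block on it (not shown in this hunk) can keep the familiar `awaitResult(future, timeout)` shape, since overload resolution now picks the `JFuture` variant. Below is a minimal consumer-side sketch with the same caveats as above (illustrative names, placed under `org.apache.spark` only so the internal helper resolves); a finite timeout exercises the `future.get(length, unit)` branch of the new overload.

```scala
// Illustrative sketch only; the pool, Callable, and Array[Int] payload are stand-ins.
package org.apache.spark.sketch

import java.util.concurrent.{Callable, Executors, Future => JFuture}

import scala.concurrent.duration._

import org.apache.spark.util.ThreadUtils

object AwaitJavaFutureSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)

    // Producer side: the background work now hands back a java.util.concurrent.Future.
    val task: Callable[Array[Int]] = () => Array(1, 2, 3)
    val rowsFuture: JFuture[Array[Int]] = pool.submit(task)

    // Consumer side keeps the awaitResult(future, timeout) shape; a finite Duration
    // goes through future.get(length, unit), Duration.Inf through future.get().
    val rows = ThreadUtils.awaitResult(rowsFuture, 10.seconds)
    assert(rows.sameElements(Array(1, 2, 3)))

    pool.shutdown()
  }
}
```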
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 36f0d173cd0b0..65e6b7c2f0fba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.HashedRelation
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
-import org.apache.spark.util.{SparkFatalException, ThreadUtils}
+import org.apache.spark.util.{SparkFatalException, ThreadUtils, Utils}
 
 /**
  * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
@@ -73,13 +73,8 @@ case class BroadcastExchangeExec(
 
   @transient
   private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
-    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    val task = new Callable[broadcast.Broadcast[Any]]() {
-      override def call(): broadcast.Broadcast[Any] = {
-        // This will run in another thread. Set the execution id so that we can connect these jobs
-        // with the correct execution.
-        SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
+    SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
+      sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
           try {
             // Setup a job group here so later it may get cancelled by groupId if necessary.
             sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
@@ -121,7 +116,7 @@
             val broadcasted = sparkContext.broadcast(relation)
             longMetric("broadcastTime") += NANOSECONDS.toMillis(
               System.nanoTime() - beforeBroadcast)
-
+            val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
             SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
             promise.success(broadcasted)
             broadcasted
@@ -146,10 +141,7 @@
               promise.failure(e)
               throw e
           }
-        }
-      }
     }
-    BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
   }
 
   override protected def doPrepare(): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
index 46d0c64592a00..888772c35d0ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -159,6 +159,34 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-22590 propagate local properties to broadcast execution thread") {
+    withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1") {
+      val df1 = Seq(true).toDF()
+      val confKey = "spark.sql.y"
+      val confValue1 = UUID.randomUUID().toString()
+      val confValue2 = UUID.randomUUID().toString()
+
+      def generateBroadcastDataFrame(confKey: String, confValue: String): Dataset[Boolean] = {
+        val df = spark.range(1).mapPartitions { _ =>
+          Iterator(TaskContext.get.getLocalProperty(confKey) == confValue)
+        }
+        df.hint("broadcast")
+      }
+
+      // set local property and assert
+      val df2 = generateBroadcastDataFrame(confKey, confValue1)
+      spark.sparkContext.setLocalProperty(confKey, confValue1)
+      val checks = df1.join(df2).collect()
+      assert(checks.forall(_.toSeq == Seq(true, true)))
+
+      // change local property and re-assert
+      val df3 = generateBroadcastDataFrame(confKey, confValue2)
+      spark.sparkContext.setLocalProperty(confKey, confValue2)
+      val checks2 = df1.join(df3).collect()
+      assert(checks2.forall(_.toSeq == Seq(true, true)))
+    }
+  }
 }
 
 case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {
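For context on the new test: `BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD` is pinned to 1 so that both broadcast jobs run on the same reused pool thread, which is exactly the situation where missing or stale local properties would surface if they were not re-captured per submission. Below is a rough standalone sketch of the user-visible behavior the test locks in; the `spark.sql.y` key, the `local[2]` session, and the column handling are illustrative choices, not code from the patch.

```scala
// Illustrative sketch only; assumes a local SparkSession and an arbitrary property key.
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object LocalPropertyBroadcastSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Set a local property on the driver thread right before triggering the action.
    spark.sparkContext.setLocalProperty("spark.sql.y", "expected")

    // The broadcast side reads the property from inside the task that collects it.
    val broadcastSide = spark.range(1)
      .mapPartitions { _ => Iterator(TaskContext.get.getLocalProperty("spark.sql.y")) }
      .hint("broadcast")

    val rows = Seq(true).toDF("marker").join(broadcastSide).collect()
    // With the fix every row carries "expected"; before, the broadcast job could observe
    // null or a value left over from an earlier submission on the reused pool thread.
    assert(rows.forall(_.getString(1) == "expected"))

    spark.stop()
  }
}
```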