15 changes: 15 additions & 0 deletions core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util

import java.util.concurrent._
import java.util.concurrent.{Future => JFuture}
import java.util.concurrent.locks.ReentrantLock

import scala.concurrent.{Awaitable, ExecutionContext, ExecutionContextExecutor, Future}
@@ -304,6 +305,20 @@ private[spark] object ThreadUtils {
  }
  // scalastyle:on awaitresult

  @throws(classOf[SparkException])
  def awaitResult[T](future: JFuture[T], atMost: Duration): T = {
    try {
      atMost match {
        case Duration.Inf => future.get()
        case _ => future.get(atMost._1, atMost._2)
      }
    } catch {
      case NonFatal(t)
          if !t.isInstanceOf[TimeoutException] && !t.isInstanceOf[RpcAbortException] =>
        throw new SparkException("Exception thrown in awaitResult: ", t)
    }
  }

  // scalastyle:off awaitready
  /**
   * Preferred alternative to `Await.ready()`.
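For reference, a minimal usage sketch of the new overload (not part of the diff, and assuming the calling code lives under an org.apache.spark package, since ThreadUtils is private[spark]): it blocks on a java.util.concurrent.Future and wraps non-fatal, non-timeout failures in SparkException, mirroring the existing Scala-future variant.

// Hypothetical example, placed under org.apache.spark only so the private[spark] ThreadUtils is visible.
package org.apache.spark.util

import java.util.concurrent.{Callable, Executors}

import scala.concurrent.duration._

object AwaitJavaFutureSketch {
  def main(args: Array[String]): Unit = {
    val pool = Executors.newSingleThreadExecutor()
    try {
      // Submit a Callable explicitly so overload resolution between submit(Callable)
      // and submit(Runnable) is unambiguous.
      val task: Callable[Int] = () => 40 + 2
      val javaFuture = pool.submit(task)
      // New overload: waits up to 10 seconds; timeouts are rethrown as-is, other
      // non-fatal exceptions are wrapped in SparkException.
      val result = ThreadUtils.awaitResult(javaFuture, 10.seconds)
      assert(result == 42)
    } finally {
      pool.shutdown()
    }
  }
}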
sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -17,11 +17,9 @@

package org.apache.spark.sql.execution

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.{ConcurrentHashMap, ExecutorService, Future => JFuture}
import java.util.concurrent.atomic.AtomicLong

import scala.concurrent.{ExecutionContext, Future}

import org.apache.spark.SparkContext
import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.sql.SparkSession
@@ -172,11 +170,11 @@ object SQLExecution {
   * SparkContext local properties are forwarded to execution thread
   */
  def withThreadLocalCaptured[T](
      sparkSession: SparkSession, exec: ExecutionContext)(body: => T): Future[T] = {
      sparkSession: SparkSession, exec: ExecutorService) (body: => T): JFuture[T] = {
    val activeSession = sparkSession
    val sc = sparkSession.sparkContext
    val localProps = Utils.cloneProperties(sc.getLocalProperties)
    Future {
    exec.submit(() => {
      val originalSession = SparkSession.getActiveSession
      val originalLocalProps = sc.getLocalProperties
      SparkSession.setActiveSession(activeSession)
@@ -190,6 +188,6 @@ object SQLExecution {
        SparkSession.clearActiveSession()
      }
      res
    }(exec)
    })
  }
}
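A hedged sketch of the new calling convention (illustrative helper, not part of the PR; it assumes an org.apache.spark package so the private[spark] ThreadUtils is visible): the body now runs on a plain ExecutorService, the caller gets back a java.util.concurrent.Future, and the active SparkSession plus the calling thread's SparkContext local properties are visible inside the body — which is what the relationFuture call sites below rely on.

// Hypothetical helper; the package is chosen only so ThreadUtils is accessible.
package org.apache.spark.sql.execution

import java.util.concurrent.{ExecutorService, Future => JFuture}

import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.ThreadUtils

object LocalPropertyCaptureSketch {
  def readPropertyOnPool(spark: SparkSession): String = {
    val sc = spark.sparkContext
    sc.setLocalProperty("spark.sql.y", "expected-value")

    val pool: ExecutorService = ThreadUtils.newDaemonSingleThreadExecutor("capture-sketch")
    try {
      // Runs on `pool`, but sees the caller's SparkSession and local properties.
      val future: JFuture[String] = SQLExecution.withThreadLocalCaptured(spark, pool) {
        sc.getLocalProperty("spark.sql.y")
      }
      ThreadUtils.awaitResult(future, Duration.Inf) // returns "expected-value"
    } finally {
      pool.shutdown()
    }
  }
}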
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -17,10 +17,11 @@

package org.apache.spark.sql.execution

import java.util.concurrent.{Future => JFuture}
import java.util.concurrent.TimeUnit._

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.{ExecutionContext}
import scala.concurrent.duration.Duration

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
@@ -746,7 +747,7 @@ case class SubqueryExec(name: String, child: SparkPlan)
"collectTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to collect"))

@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
private lazy val relationFuture: JFuture[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLExecution.withThreadLocalCaptured[Array[InternalRow]](
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.HashedRelation
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.util.{SparkFatalException, ThreadUtils}
import org.apache.spark.util.{SparkFatalException, ThreadUtils, Utils}

/**
* A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
@@ -73,13 +73,8 @@ case class BroadcastExchangeExec(

@transient
private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
Review comments on this line:

Member:
I usually don't like such changes. I wonder if we can keep the indentation the same, to make it easier to track the history of commits.

Contributor Author:
@HyukjinKwon I have tried to adjust the indent so that the diff shows only the lines I have modified. Does it seem OK now?

Contributor:
The indentation is wrong now, although the diff is smaller. How about:

private[sql] lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
  SQLExecution.withThreadLocalCaptured ... {
    doBroadcast(..)
  }
}

private def doBroadcast ... {
  // original code
}

Contributor Author:
@cloud-fan Updated as per your suggestion, but it still changes the indentation compared to the original. Is that OK?

Contributor:
Ah, sorry, I misread the code. It seems we can't avoid changing the indentation, as it was so nested before. I'm OK with your original code.

Contributor Author:
reverted

val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
val task = new Callable[broadcast.Broadcast[Any]]() {
override def call(): broadcast.Broadcast[Any] = {
// This will run in another thread. Set the execution id so that we can connect these jobs
// with the correct execution.
SQLExecution.withExecutionId(sqlContext.sparkSession, executionId) {
SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
try {
// Setup a job group here so later it may get cancelled by groupId if necessary.
sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
@@ -121,7 +116,7 @@ case class BroadcastExchangeExec(
val broadcasted = sparkContext.broadcast(relation)
longMetric("broadcastTime") += NANOSECONDS.toMillis(
System.nanoTime() - beforeBroadcast)

val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
promise.success(broadcasted)
broadcasted
@@ -146,10 +141,7 @@ case class BroadcastExchangeExec(
promise.failure(e)
throw e
}
}
}
}
BroadcastExchangeExec.executionContext.submit[broadcast.Broadcast[Any]](task)
}

override protected def doPrepare(): Unit = {
sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala
@@ -159,6 +159,34 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils {
}
}
}

test("SPARK-22590 propagate local properties to broadcast execution thread") {
withSQLConf(StaticSQLConf.BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD.key -> "1") {
val df1 = Seq(true).toDF()
val confKey = "spark.sql.y"
val confValue1 = UUID.randomUUID().toString()
val confValue2 = UUID.randomUUID().toString()

def generateBroadcastDataFrame(confKey: String, confValue: String): Dataset[Boolean] = {
val df = spark.range(1).mapPartitions { _ =>
Iterator(TaskContext.get.getLocalProperty(confKey) == confValue)
}
df.hint("broadcast")
}

// set local propert and assert
val df2 = generateBroadcastDataFrame(confKey, confValue1)
spark.sparkContext.setLocalProperty(confKey, confValue1)
val checks = df1.join(df2).collect()
assert(checks.forall(_.toSeq == Seq(true, true)))

// change local property and re-assert
val df3 = generateBroadcastDataFrame(confKey, confValue2)
spark.sparkContext.setLocalProperty(confKey, confValue2)
val checks2 = df1.join(df3).collect()
assert(checks2.forall(_.toSeq == Seq(true, true)))
}
}
}
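The test above pins BROADCAST_EXCHANGE_MAX_THREAD_THRESHOLD to 1 so the same broadcast thread is reused for both joins. A standalone illustration of why reuse matters (plain JVM code, not Spark): values stored in an InheritableThreadLocal — which is how SparkContext keeps its local properties — are copied only when a thread is created, so a reused pooled thread keeps seeing the value from creation time unless the properties are forwarded explicitly, as withThreadLocalCaptured now does.

import java.util.concurrent.{Callable, Executors}

object StaleLocalPropertySketch {
  // Stand-in for SparkContext's local properties, which live in an InheritableThreadLocal.
  private val prop = new InheritableThreadLocal[String] {
    override def initialValue(): String = "unset"
  }

  def main(args: Array[String]): Unit = {
    val pool = Executors.newFixedThreadPool(1)
    val read: Callable[String] = () => prop.get()

    prop.set("v1")
    println(pool.submit(read).get()) // "v1": the worker thread inherits the value when it is created

    prop.set("v2")
    println(pool.submit(read).get()) // still "v1": the reused thread never observes the update

    pool.shutdown()
  }
}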

case class SQLConfAssertPlan(confToCheck: Seq[(String, String)]) extends LeafExecNode {