diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 40bf094856bca..b46ddc2eaf121 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.concurrent.{ExecutionContext, Future}
 
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -392,11 +393,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     results.toArray
   }
 
-  private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = {
-    val countsAndBytes = getByteArrayRdd().collect()
-    val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
-    (total, rows)
+  /**
+   * Runs this query asynchronously and returns a future of the collect result.
+   */
+  private[spark] def executeCollectIteratorFuture()(
+      implicit executor: ExecutionContext): Future[(Long, Iterator[InternalRow])] = {
+    val future = getByteArrayRdd().collectAsync()
+    future.map(countsAndBytes => {
+      val total = countsAndBytes.map(_._1).sum
+      val rows = countsAndBytes.iterator
+        .flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
+      (total, rows)
+    })
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 531960a2477ed..a6d3ce076f4f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -21,7 +21,7 @@ import java.util.UUID
 import java.util.concurrent._
 
 import scala.concurrent.{ExecutionContext, Promise}
-import scala.concurrent.duration.NANOSECONDS
+import scala.concurrent.duration.{Duration, NANOSECONDS}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{broadcast, SparkException}
@@ -107,15 +107,22 @@ case class BroadcastExchangeExec(
 
   @transient
   override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+      interruptOnCancel = true)
+    val beforeCollect = System.nanoTime()
+    // SPARK-35414: use executeCollectIteratorFuture() to submit the collect job before the
+    // relationFuture runs, so the broadcast job is submitted ahead of the shuffle map jobs in AQE.
+    val collectFuture = child.executeCollectIteratorFuture()(
+      BroadcastExchangeExec.executionContext)
+
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
       try {
         // Setup a job group here so later it may get cancelled by groupId if necessary.
         sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
           interruptOnCancel = true)
-        val beforeCollect = System.nanoTime()
-        // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-        val (numRows, input) = child.executeCollectIterator()
+        val (numRows, input) = ThreadUtils.awaitResult(collectFuture, Duration.Inf)
         longMetric("numOutputRows") += numRows
         if (numRows >= MAX_BROADCAST_TABLE_ROWS) {
           throw new SparkException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 690fdb1dc15fc..8f177f5f52313 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
@@ -1817,4 +1817,37 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
+    val shuffleMapTaskParallelism = 10
+
+    val df = spark.range(0, 1000, 1, shuffleMapTaskParallelism)
+      .select($"id" % 26, $"id" % 10)
+      .toDF("index", "pv")
+    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
+      .toDF("index", "name")
+      .coalesce(1)
+    val testDf = df.groupBy("index")
+      .agg(sum($"pv").alias("pv"))
+      .join(dim, Seq("index"))
+
+    val stageInfos = scala.collection.mutable.ArrayBuffer[StageInfo]()
+    val listener = new SparkListener {
+      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+        stageInfos += stageSubmitted.stageInfo
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val result = testDf.collect()
+      assert(result.length == 26)
+      val sortedStageInfos = stageInfos.sortBy(_.submissionTime)
+      assert(sortedStageInfos.size > 2)
+      // the first submitted stage is the broadcast stage
+      assert(sortedStageInfos(0).numTasks == 1)
+      // the second submitted stage is the shuffle map stage
+      assert(sortedStageInfos(1).numTasks == shuffleMapTaskParallelism)
+    }
+  }
 }
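The ordering fix above hinges on `RDD.collectAsync()` submitting its job at call time and returning a `FutureAction`, whereas the old path only submitted the collect job once a thread from `BroadcastExchangeExec.executionContext` got around to running the `relationFuture` body. A minimal standalone sketch of that contrast follows; it is not part of the patch, and the object name, `local[2]` master, and toy data are invented for illustration:

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession

// Illustrative sketch only: contrasts eager job submission via collectAsync()
// with deferred submission via a Future wrapping a blocking collect().
object EagerSubmitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    implicit val ec: ExecutionContext = ExecutionContext.global

    val rdd = spark.sparkContext.parallelize(1 to 100, numSlices = 4)

    // collectAsync() submits the job to the scheduler immediately on this
    // thread and returns a FutureAction (which is also a scala.concurrent.Future).
    val eager = rdd.collectAsync()

    // Wrapping a blocking collect() in Future {} defers job submission until
    // the execution context schedules the closure; under a saturated thread
    // pool that can be arbitrarily late, which is the broadcast-timeout
    // scenario SPARK-35414 targets.
    val deferred = Future { rdd.map(_ * 2).collect() }

    println(s"eager: ${Await.result(eager, Duration.Inf).length} elements")
    println(s"deferred: ${Await.result(deferred, Duration.Inf).length} elements")
    spark.stop()
  }
}
```

The patch itself awaits with `ThreadUtils.awaitResult(collectFuture, Duration.Inf)` rather than `Await.result`, following Spark's convention of wrapping failures in a `SparkException` that carries the awaiting thread's stack trace.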