From 03f88dd4a24b547ba8ca162ef189d29f3fd1caf7 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Sun, 16 May 2021 23:10:12 +0800
Subject: [PATCH 1/5] SPARK-35414: Submit broadcast job first to avoid
 broadcast timeout in AQE

---
 .../spark/sql/execution/SparkPlan.scala       | 18 ++++++---
 .../exchange/BroadcastExchangeExec.scala      | 12 ++++--
 .../adaptive/AdaptiveQueryExecSuite.scala     | 38 ++++++++++++++++++-
 3 files changed, 58 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 40bf094856bca..b46ddc2eaf121 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.concurrent.{ExecutionContext, Future}
 
 import org.apache.spark.{broadcast, SparkEnv}
 import org.apache.spark.internal.Logging
@@ -392,11 +393,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     results.toArray
   }
 
-  private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = {
-    val countsAndBytes = getByteArrayRdd().collect()
-    val total = countsAndBytes.map(_._1).sum
-    val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
-    (total, rows)
+  /**
+   * Runs this query asynchronously and returns a future of the collected result.
+   */
+  private[spark] def executeCollectIteratorFuture()(
+      implicit executor: ExecutionContext): Future[(Long, Iterator[InternalRow])] = {
+    val future = getByteArrayRdd().collectAsync()
+    future.map(countsAndBytes => {
+      val total = countsAndBytes.map(_._1).sum
+      val rows = countsAndBytes.iterator
+        .flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2))
+      (total, rows)
+    })
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 531960a2477ed..5acb43d02a982 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -21,7 +21,7 @@ import java.util.UUID
 import java.util.concurrent._
 
 import scala.concurrent.{ExecutionContext, Promise}
-import scala.concurrent.duration.NANOSECONDS
+import scala.concurrent.duration.{Duration, NANOSECONDS}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{broadcast, SparkException}
@@ -107,15 +107,19 @@ case class BroadcastExchangeExec(
 
   @transient
   override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    val beforeCollect = System.nanoTime()
+    // SPARK-35414: use executeCollectIteratorFuture() to submit a job before getting
+    // relationFuture, ensuring the broadcast job is submitted before AQE's shuffle map jobs.
+    val collectFuture = child.executeCollectIteratorFuture()(
+      BroadcastExchangeExec.executionContext)
+
     SQLExecution.withThreadLocalCaptured[broadcast.Broadcast[Any]](
       sqlContext.sparkSession, BroadcastExchangeExec.executionContext) {
       try {
         // Setup a job group here so later it may get cancelled by groupId if necessary.
         sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
           interruptOnCancel = true)
-        val beforeCollect = System.nanoTime()
-        // Use executeCollect/executeCollectIterator to avoid conversion to Scala types
-        val (numRows, input) = child.executeCollectIterator()
+        val (numRows, input) = ThreadUtils.awaitResult(collectFuture, Duration.Inf)
         longMetric("numOutputRows") += numRows
         if (numRows >= MAX_BROADCAST_TABLE_ROWS) {
           throw new SparkException(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 454d3aa148a44..72d184ff9ef24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
@@ -1641,4 +1641,40 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
+    val broadcastTimeoutInSec = 10
+    val shuffleMapTaskParallelism = 10
+
+    val df = spark.sparkContext.parallelize(Range(0, 10), shuffleMapTaskParallelism)
+      .flatMap(x => {
+        Thread.sleep(10)
+        for (i <- Range(0, 100)) yield (x % 26, x % 10)
+      }).toDF("index", "pv")
+    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
+      .toDF("index", "name")
+    val testDf = df.groupBy("index")
+      .agg(sum($"pv").alias("pv"))
+      .join(dim, Seq("index"))
+
+    val stageInfos = scala.collection.mutable.ArrayBuffer[StageInfo]()
+    val listener = new SparkListener {
+      override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+        stageInfos += stageSubmitted.stageInfo
+      }
+    }
+    spark.sparkContext.addSparkListener(listener)
+
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString,
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val result = testDf.collect()
+      assert(result.length == 26)
+      val sortedStageInfos = stageInfos.sortBy(_.submissionTime)
+      assert(sortedStageInfos.size > 2)
+      // this is the broadcast stage
+      assert(sortedStageInfos(0).numTasks == 1)
+      // this is the shuffle map stage
+      assert(sortedStageInfos(1).numTasks == shuffleMapTaskParallelism)
+    }
+  }
 }
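The ordering trick patch 1/5 relies on can be seen in isolation: RDD.collectAsync() hands its job to the scheduler as soon as it is called, whereas a plain collect() inside a lazily started future runs only when that future is scheduled. The following standalone sketch shows the pattern; the local session, the object name, and the RDD contents are illustrative assumptions, not code from the patch.

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import org.apache.spark.sql.SparkSession

object EagerSubmitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("collectAsync-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // collectAsync() submits its job to the DAGScheduler immediately and
    // returns a FutureAction; nothing blocks until the result is awaited.
    val broadcastSide = sc.parallelize(1 to 100, 2).map(_ * 2).collectAsync()

    // Jobs submitted afterwards queue up behind the collect job, which is the
    // ordering the patch wants between the broadcast job and AQE's shuffle jobs.
    val keyCount = sc.parallelize(1 to 1000, 10).map(i => (i % 26, i)).groupByKey().count()

    val rows = Await.result(broadcastSide, Duration.Inf)
    println(s"collected ${rows.length} rows; counted $keyCount records")
    spark.stop()
  }
}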
From 4677620bb1db90332b2ae2b7c9c84e3d5a7f9053 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Mon, 17 May 2021 10:33:51 +0800
Subject: [PATCH 2/5] fix UT

---
 .../sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 72d184ff9ef24..5ed2b97cbd816 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1643,10 +1643,9 @@ class AdaptiveQueryExecSuite
   }
 
   test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
-    val broadcastTimeoutInSec = 10
     val shuffleMapTaskParallelism = 10
 
-    val df = spark.sparkContext.parallelize(Range(0, 10), shuffleMapTaskParallelism)
+    val df = spark.sparkContext.parallelize(Range(0, 26), shuffleMapTaskParallelism)
       .flatMap(x => {
         Thread.sleep(10)
         for (i <- Range(0, 100)) yield (x % 26, x % 10)
@@ -1665,8 +1664,7 @@ class AdaptiveQueryExecSuite
     }
     spark.sparkContext.addSparkListener(listener)
 
-    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString,
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+    withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
       val result = testDf.collect()
       assert(result.length == 26)
       val sortedStageInfos = stageInfos.sortBy(_.submissionTime)
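Patch 2/5 drops the broadcast-timeout override because the eager submission makes it unnecessary. The verification mechanism the test keeps, recording StageInfo from onStageSubmitted and sorting by submission time, also works outside the suite. A rough self-contained version follows; the query, the sleep that lets the asynchronous listener bus drain, and all names are assumptions of this sketch:

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.{SparkListener, SparkListenerStageSubmitted, StageInfo}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object StageOrderSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("stage-order-sketch")
      .getOrCreate()

    // Record every stage the scheduler submits, in arrival order.
    val stageInfos = ArrayBuffer[StageInfo]()
    spark.sparkContext.addSparkListener(new SparkListener {
      override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit =
        stageInfos.synchronized { stageInfos += event.stageInfo }
    })

    spark.range(0, 1000, 1, 10)
      .groupBy((col("id") % 26).as("k"))
      .count()
      .collect()

    // The listener bus is asynchronous; a short sleep is a crude way to let
    // the last events arrive before reading the buffer.
    Thread.sleep(1000)
    stageInfos.sortBy(_.submissionTime).foreach { info =>
      println(s"stage ${info.stageId}: ${info.numTasks} tasks, submitted at ${info.submissionTime}")
    }
    spark.stop()
  }
}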
From ea0dbf8e26f2779dee8da9c5c29c7d37a726c5f5 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Tue, 18 May 2021 15:44:31 +0800
Subject: [PATCH 3/5] fix UT

---
 .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 5ed2b97cbd816..ff4d4ca8a05b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1652,6 +1652,7 @@ class AdaptiveQueryExecSuite
       }).toDF("index", "pv")
     val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
       .toDF("index", "name")
+      .coalesce(1)
     val testDf = df.groupBy("index")
       .agg(sum($"pv").alias("pv"))
       .join(dim, Seq("index"))

From fbf8b8cf101dce393cdfd3f8ae74eb882aafb949 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Tue, 15 Jun 2021 16:09:20 +0800
Subject: [PATCH 4/5] fix bug to pass the "cancel the job group" UT

---
 .../sql/execution/exchange/BroadcastExchangeExec.scala  | 3 +++
 .../sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 8 +++-----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index 5acb43d02a982..a6d3ce076f4f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -107,6 +107,9 @@ case class BroadcastExchangeExec(
 
   @transient
   override lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+    // Setup a job group here so later it may get cancelled by groupId if necessary.
+    sparkContext.setJobGroup(runId.toString, s"broadcast exchange (runId $runId)",
+      interruptOnCancel = true)
     val beforeCollect = System.nanoTime()
     // SPARK-35414: use executeCollectIteratorFuture() to submit a job before getting
     // relationFuture, ensuring the broadcast job is submitted before AQE's shuffle map jobs.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index ff4d4ca8a05b7..8cbc6ffb49194 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1645,11 +1645,9 @@ class AdaptiveQueryExecSuite
   test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
     val shuffleMapTaskParallelism = 10
 
-    val df = spark.sparkContext.parallelize(Range(0, 26), shuffleMapTaskParallelism)
-      .flatMap(x => {
-        Thread.sleep(10)
-        for (i <- Range(0, 100)) yield (x % 26, x % 10)
-      }).toDF("index", "pv")
+    val df = spark.range(0, 1000, 1, shuffleMapTaskParallelism)
+      .select($"id" % 26, $"id" % 10)
+      .toDF("index", "pv")
     val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
       .toDF("index", "name")
       .coalesce(1)
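Patch 4/5 hoists setJobGroup out of the relation future because the group id lives in thread-local properties: with the collect job now submitted eagerly on the caller thread, the group must be set on that same thread before submission, or cancelJobGroup(runId) finds no job to cancel. A small sketch of that mechanism follows; the session setup, group id, and slow job are hypothetical:

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession

object JobGroupCancelSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("job-group-sketch")
      .getOrCreate()
    val sc = spark.sparkContext

    // The group id is stored in this thread's local properties, so it must be
    // set on the same thread that submits the job we intend to cancel.
    sc.setJobGroup("demo-group", "cancellable demo job", interruptOnCancel = true)
    val slowJob = sc.parallelize(1 to 10000, 4)
      .map { i => Thread.sleep(1); i }
      .countAsync()

    // From any thread: cancel everything submitted under the group id.
    sc.cancelJobGroup("demo-group")
    try {
      Await.result(slowJob, Duration.Inf)
    } catch {
      case e: SparkException => println(s"job cancelled as expected: ${e.getMessage}")
    }
    spark.stop()
  }
}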
From 3bf85714608b3b9d1934ff32bdea65210b977adf Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Tue, 15 Jun 2021 21:11:27 +0800
Subject: [PATCH 5/5] remove whitespace at end of line

---
 .../spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c8af6eaf42522..8f177f5f52313 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1817,7 +1817,7 @@ class AdaptiveQueryExecSuite
       }
     }
   }
-
+
   test("SPARK-35414: Submit broadcast job first to avoid broadcast timeout in AQE") {
     val shuffleMapTaskParallelism = 10
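Taken together, the final shape of the test boils down to an AQE plan whose broadcast side is a one-partition dimension frame joined against an aggregated fact frame. A rough standalone rendering of that scenario is below; the local session setup is an assumption of the sketch, not part of the patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

object Spark35414Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("spark-35414-sketch")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    val shuffleMapTaskParallelism = 10
    val df = spark.range(0, 1000, 1, shuffleMapTaskParallelism)
      .select($"id" % 26, $"id" % 10)
      .toDF("index", "pv")
    // One partition keeps the dimension side to a single broadcast task.
    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
      .toDF("index", "name")
      .coalesce(1)

    // The broadcast job for dim should be submitted before the shuffle map
    // jobs of the aggregation, so the broadcast no longer waits behind them.
    val result = df.groupBy("index")
      .agg(sum($"pv").alias("pv"))
      .join(dim, Seq("index"))
      .collect()
    assert(result.length == 26)
    spark.stop()
  }
}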