From e9ae1e895272b54e7418d4d841f1ef6f7b54ca43 Mon Sep 17 00:00:00 2001 From: Yu Zhong Date: Tue, 29 Dec 2020 17:31:04 +0800 Subject: [PATCH] SPARK-33933: materialize BroadcastQueryStageExec first to avoid broadcast timeout in AQE --- .../adaptive/AdaptiveSparkPlanExec.scala | 10 +++++++- .../execution/joins/BroadcastJoinSuite.scala | 23 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 187827ca6005e..deace5be1ef0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -171,8 +171,16 @@ case class AdaptiveSparkPlanExec( stagesToReplace = result.newStages ++ stagesToReplace executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan))) + // We should materialize BroadcastQueryStageExec first to avoid broadcast timeout + // Sort the new stages by class type to make sure BroadcastQueryStageExec precedes others + val reorderedNewStages = result.newStages + .sortWith { + case (_: BroadcastQueryStageExec, _: BroadcastQueryStageExec) => false + case (_: BroadcastQueryStageExec, _) => true + case _ => false + } // Start materialization of all new stages and fail fast if any stages failed eagerly - result.newStages.foreach { stage => + reorderedNewStages.foreach { stage => try { stage.materialize().onComplete { res => if (res.isSuccess) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index ef0a596f21104..452bf36d085af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -421,6 +421,29 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(e.getMessage.contains(s"Could not execute broadcast in $timeout secs.")) } } + + test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") { + val broadcastTimeoutInSec = 5 + val df = spark.sparkContext.parallelize(Range(0, 1000), 1000) + .flatMap(x => { + Thread.sleep(20) + for (i <- Range(0, 100)) yield (x % 26, x % 10) + }).toDF("index", "pv") + val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) + .toDF("index", "name") + val testDf = df.groupBy("index") + .agg(sum($"pv").alias("pv")) + .join(dim, Seq("index")) + withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString, + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val startTime = System.currentTimeMillis() + val result = testDf.collect() + val queryTime = System.currentTimeMillis() - startTime + assert(result.length == 26) + // Sanity check: the query must run longer than the broadcast timeout plus 1s of slack + assert(queryTime > (broadcastTimeoutInSec + 1) * 1000) + } + } } class BroadcastJoinSuite extends BroadcastJoinSuiteBase with DisableAdaptiveExecutionSuite