diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 187827ca6005e..c5558665dcc1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -171,8 +171,17 @@ case class AdaptiveSparkPlanExec( stagesToReplace = result.newStages ++ stagesToReplace executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan))) + // SPARK-33933: we should submit tasks of broadcast stages first, to avoid waiting + // for tasks to be scheduled and leading to broadcast timeout. + val reorderedNewStages = result.newStages + .sortWith { + case (_: BroadcastQueryStageExec, _: BroadcastQueryStageExec) => false + case (_: BroadcastQueryStageExec, _) => true + case _ => false + } + // Start materialization of all new stages and fail fast if any stages failed eagerly - result.newStages.foreach { stage => + reorderedNewStages.foreach { stage => try { stage.materialize().onComplete { res => if (res.isSuccess) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 6d97a6bb47d0f..f30bf539624fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -903,4 +903,28 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") { + val broadcastTimeoutInSec = 1 + val df = spark.sparkContext.parallelize(Range(0, 100), 100) + .flatMap(x => { + Thread.sleep(20) + for (i <- Range(0, 100)) yield (x % 26, x % 10) + }).toDF("index", "pv") + val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) + .toDF("index", "name") + val testDf = df.groupBy("index") + .agg(sum($"pv").alias("pv")) + .join(dim, Seq("index")) + withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString, + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val startTime = System.currentTimeMillis() + val result = testDf.collect() + val queryTime = System.currentTimeMillis() - startTime + assert(result.length == 26) + // make sure the execution time is large enough + assert(queryTime > (broadcastTimeoutInSec + 1) * 1000) + } + } + }