diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 187827ca6005e..06bca33764c93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -171,8 +171,20 @@ case class AdaptiveSparkPlanExec( stagesToReplace = result.newStages ++ stagesToReplace executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan))) + // SPARK-33933: submit tasks of broadcast stages first, to avoid waiting for + // task scheduling, which can lead to broadcast timeouts. + // This partial fix only guarantees that materialization of a BroadcastQueryStage + // starts before other stages; since the collect job for broadcasting is submitted + // in another thread, the issue is not completely resolved. + val reorderedNewStages = result.newStages + .sortWith { + case (_: BroadcastQueryStageExec, _: BroadcastQueryStageExec) => false + case (_: BroadcastQueryStageExec, _) => true + case _ => false + } + // Start materialization of all new stages and fail fast if any stages failed eagerly - result.newStages.foreach { stage => + reorderedNewStages.foreach { stage => try { stage.materialize().onComplete { res => if (res.isSuccess) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 74fe1eaab6f64..91a91a656a346 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -76,6 +76,7 @@ abstract class QueryStageExec extends LeafExecNode { * stage is ready. 
*/ final def materialize(): Future[Any] = executeQuery { + logDebug(s"Materialize query stage ${this.getClass.getSimpleName}: $id") doMaterialize() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 6d97a6bb47d0f..dca4d88f840fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -903,4 +903,27 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-33933: Materialize BroadcastQueryStage first in AQE") { + val testAppender = new LogAppender("aqe query stage materialization order test") + val df = spark.range(1000).select($"id" % 26, $"id" % 10) + .toDF("index", "pv") + val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString)) + .toDF("index", "name") + val testDf = df.groupBy("index") + .agg(sum($"pv").alias("pv")) + .join(dim, Seq("index")) + withLogAppender(testAppender, level = Some(Level.DEBUG)) { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + val result = testDf.collect() + assert(result.length == 26) + } + } + val materializeLogs = testAppender.loggingEvents + .map(_.getRenderedMessage) + .filter(_.startsWith("Materialize query stage")) + .toArray + assert(materializeLogs(0).startsWith("Materialize query stage BroadcastQueryStageExec")) + assert(materializeLogs(1).startsWith("Materialize query stage ShuffleQueryStageExec")) + } }