From 25be218df75a8cc62da746d9a849a6ea9fb7ee67 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Sun, 3 Jan 2021 20:41:18 +0800
Subject: [PATCH 1/4] SPARK-33933: materialize broadcast query stages first to
 avoid broadcast timeout in AQE

---
 .../adaptive/AdaptiveSparkPlanExec.scala      | 11 ++++++++-
 .../execution/joins/BroadcastJoinSuite.scala  | 23 +++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 89d3b53510469..68e21a571c347 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -189,8 +189,17 @@ case class AdaptiveSparkPlanExec(
           stagesToReplace = result.newStages ++ stagesToReplace
           executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan)))
 
+          // SPARK-33933: We should materialize BroadcastQueryState first to avoid broadcast timeout
+          // Sort the new stages by class type to make sure BroadcastQueryState precede others
+          val reorderedNewStages = result.newStages
+            .sortWith {
+              case (_: BroadcastQueryStageExec, _: BroadcastQueryStageExec) => false
+              case (_: BroadcastQueryStageExec, _) => true
+              case _ => false
+            }
+
           // Start materialization of all new stages and fail fast if any stages failed eagerly
-          result.newStages.foreach { stage =>
+          reorderedNewStages.foreach { stage =>
             try {
               stage.materialize().onComplete { res =>
                 if (res.isSuccess) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 98a1089709b92..31ffbdc59182a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -425,6 +425,29 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     }
   }
 
+  test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") {
+    val broadcastTimeoutInSec = 5
+    val df = spark.sparkContext.parallelize(Range(0, 1000), 1000)
+      .flatMap(x => {
+        Thread.sleep(20)
+        for (i <- Range(0, 100)) yield (x % 26, x % 10)
+      }).toDF("index", "pv")
+    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
+      .toDF("index", "name")
+    val testDf = df.groupBy("index")
+      .agg(sum($"pv").alias("pv"))
+      .join(dim, Seq("index"))
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString,
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val startTime = System.currentTimeMillis()
+      val result = testDf.collect()
+      val queryTime = System.currentTimeMillis() - startTime
+      assert(result.length == 26)
+      // make sure the execution time is large enough
+      assert(queryTime > (broadcastTimeoutInSec + 1) * 1000)
+    }
+  }
+
   test("broadcast join where streamed side's output partitioning is HashPartitioning") {
     withTable("t1", "t3") {
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {

From c49065cd3dcc836a17addf5d12aa41bfc451d775 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Mon, 4 Jan 2021 15:12:02 +0800
Subject: [PATCH 2/4] update comment

---
 .../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 68e21a571c347..aa09f21af19b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -189,8 +189,8 @@ case class AdaptiveSparkPlanExec(
           stagesToReplace = result.newStages ++ stagesToReplace
           executionId.foreach(onUpdatePlan(_, result.newStages.map(_.plan)))
 
-          // SPARK-33933: We should materialize BroadcastQueryState first to avoid broadcast timeout
-          // Sort the new stages by class type to make sure BroadcastQueryState precede others
+          // SPARK-33933: we should submit tasks of broadcast stages first, to avoid waiting
+          // for tasks to be scheduled and leading to broadcast timeout.
           val reorderedNewStages = result.newStages
             .sortWith {
               case (_: BroadcastQueryStageExec, _: BroadcastQueryStageExec) => false
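A minimal, standalone sketch of what the sortWith comparator above produces: broadcast stages move to the front so they are materialized (and their jobs submitted) first, while the relative order of the remaining stages is preserved because sortWith is a stable sort. The Stage, BroadcastStage and ShuffleStage types below are simplified stand-ins for BroadcastQueryStageExec and ShuffleQueryStageExec, not Spark internals.

sealed trait Stage { def id: Int }
case class BroadcastStage(id: Int) extends Stage  // stand-in for BroadcastQueryStageExec
case class ShuffleStage(id: Int) extends Stage    // stand-in for ShuffleQueryStageExec

object ReorderStagesSketch extends App {
  val newStages: Seq[Stage] =
    Seq(ShuffleStage(0), BroadcastStage(1), ShuffleStage(2), BroadcastStage(3))

  // Same comparator shape as the patch: a broadcast stage is "less than" any
  // non-broadcast stage; stages of the same kind compare equal, so the stable
  // sort keeps their original relative order.
  val reordered = newStages.sortWith {
    case (_: BroadcastStage, _: BroadcastStage) => false
    case (_: BroadcastStage, _) => true
    case _ => false
  }

  // Prints: List(BroadcastStage(1), BroadcastStage(3), ShuffleStage(0), ShuffleStage(2))
  println(reordered)
}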
From 83e3e4e3d19676c7996e6ca2ff72cbd40c60d485 Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Tue, 5 Jan 2021 15:51:31 +0800
Subject: [PATCH 3/4] reduce UT time within 5 sec

---
 .../apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 31ffbdc59182a..8e1db0cfbeed0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -426,8 +426,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
   }
 
   test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") {
-    val broadcastTimeoutInSec = 5
-    val df = spark.sparkContext.parallelize(Range(0, 1000), 1000)
+    val broadcastTimeoutInSec = 1
+    val df = spark.sparkContext.parallelize(Range(0, 100), 100)
       .flatMap(x => {
         Thread.sleep(20)
         for (i <- Range(0, 100)) yield (x % 26, x % 10)
       }).toDF("index", "pv")
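For context on the two SQLConf entries the test toggles: SQLConf.BROADCAST_TIMEOUT is the user-facing spark.sql.broadcastTimeout (seconds to wait for the broadcast side of a join to finish building), and SQLConf.ADAPTIVE_EXECUTION_ENABLED is spark.sql.adaptive.enabled. A sketch of setting them from application code follows; the master, app name and timeout value are made-up illustration values, not part of the patch.

import org.apache.spark.sql.SparkSession

object BroadcastTimeoutConfigSketch extends App {
  val spark = SparkSession.builder()
    .master("local[4]")                            // illustrative local master
    .appName("spark-33933-config-sketch")          // hypothetical app name
    .config("spark.sql.adaptive.enabled", "true")  // SQLConf.ADAPTIVE_EXECUTION_ENABLED
    .config("spark.sql.broadcastTimeout", "300")   // SQLConf.BROADCAST_TIMEOUT, in seconds
    .getOrCreate()

  // A join whose small side is broadcast would run here. Before this fix, AQE could
  // submit a long-running shuffle stage ahead of the broadcast stage, so the broadcast
  // spent part of its timeout just waiting in the scheduler queue.
  spark.stop()
}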
From 7cbeb14272833ac5d7e5aecc81f93c2c3a5cbadf Mon Sep 17 00:00:00 2001
From: Yu Zhong
Date: Wed, 6 Jan 2021 16:41:00 +0800
Subject: [PATCH 4/4] move UT from BroadcastJoinSuite to AdaptiveQueryExecSuite

---
 .../adaptive/AdaptiveQueryExecSuite.scala     | 24 +++++++++++++++++++
 .../execution/joins/BroadcastJoinSuite.scala  | 23 -----------------------
 2 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 69f1565c2f8de..75993d49da677 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -1431,4 +1431,28 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") {
+    val broadcastTimeoutInSec = 1
+    val df = spark.sparkContext.parallelize(Range(0, 100), 100)
+      .flatMap(x => {
+        Thread.sleep(20)
+        for (i <- Range(0, 100)) yield (x % 26, x % 10)
+      }).toDF("index", "pv")
+    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
+      .toDF("index", "name")
+    val testDf = df.groupBy("index")
+      .agg(sum($"pv").alias("pv"))
+      .join(dim, Seq("index"))
+    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString,
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      val startTime = System.currentTimeMillis()
+      val result = testDf.collect()
+      val queryTime = System.currentTimeMillis() - startTime
+      assert(result.length == 26)
+      // make sure the execution time is large enough
+      assert(queryTime > (broadcastTimeoutInSec + 1) * 1000)
+    }
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 8e1db0cfbeed0..98a1089709b92 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -425,29 +425,6 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     }
   }
 
-  test("SPARK-33933: AQE broadcast should not timeout with slow map tasks") {
-    val broadcastTimeoutInSec = 1
-    val df = spark.sparkContext.parallelize(Range(0, 100), 100)
-      .flatMap(x => {
-        Thread.sleep(20)
-        for (i <- Range(0, 100)) yield (x % 26, x % 10)
-      }).toDF("index", "pv")
-    val dim = Range(0, 26).map(x => (x, ('a' + x).toChar.toString))
-      .toDF("index", "name")
-    val testDf = df.groupBy("index")
-      .agg(sum($"pv").alias("pv"))
-      .join(dim, Seq("index"))
-    withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> broadcastTimeoutInSec.toString,
-      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
-      val startTime = System.currentTimeMillis()
-      val result = testDf.collect()
-      val queryTime = System.currentTimeMillis() - startTime
-      assert(result.length == 26)
-      // make sure the execution time is large enough
-      assert(queryTime > (broadcastTimeoutInSec + 1) * 1000)
-    }
-  }
-
   test("broadcast join where streamed side's output partitioning is HashPartitioning") {
     withTable("t1", "t3") {
       withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "500") {
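Why submission order matters, illustrated outside Spark: when execution slots are scarce, work that must complete within a deadline can time out simply because it is queued behind slow work that was submitted first. The sketch below is plain Scala futures, not Spark scheduler code; the pool size, sleep durations and the 1-second deadline are made-up illustration values.

import java.util.concurrent.{Executors, TimeoutException}
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutorService, Future}
import scala.concurrent.duration._

object SubmissionOrderSketch extends App {
  // One worker thread stands in for a cluster with no free task slots.
  implicit val ec: ExecutionContextExecutorService =
    ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(1))

  def slowMapStage(): Unit = Thread.sleep(3000)          // slow "map" work
  def broadcastBuild(): Int = { Thread.sleep(100); 26 }  // cheap "broadcast" build

  // Submit the slow work first: the broadcast build sits in the queue and misses
  // a 1-second deadline even though it only needs about 100 ms of actual work.
  val slow = Future(slowMapStage())
  val broadcast = Future(broadcastBuild())
  val timedOut =
    try { Await.result(broadcast, 1.second); false }
    catch { case _: TimeoutException => true }
  println(s"broadcast timed out when submitted after the slow stage: $timedOut") // true

  Await.ready(slow, 10.seconds)
  ec.shutdown()
}

Submitting the broadcast-like future first lets it finish well within the deadline; reordering the new query stages so broadcast stages are materialized first achieves the same effect inside AQE, which is what the patch's comment describes.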