diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 13c9528323ae..0d2875cd76d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import java.util
-import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
 
 import scala.collection.JavaConverters._
 import scala.collection.concurrent.TrieMap
 import scala.collection.mutable
@@ -255,6 +255,7 @@ case class AdaptiveSparkPlanExec(
 
           // Start materialization of all new stages and fail fast if any stages failed eagerly
           reorderedNewStages.foreach { stage =>
+            context.runningStages.put(stage.id, stage)
             try {
               stage.materialize().onComplete { res =>
                 if (res.isSuccess) {
@@ -278,8 +279,10 @@ case class AdaptiveSparkPlanExec(
         events.drainTo(rem)
         (Seq(nextMsg) ++ rem.asScala).foreach {
           case StageSuccess(stage, res) =>
+            context.runningStages.remove(stage.id)
             stage.resultOption.set(Some(res))
           case StageFailure(stage, ex) =>
+            context.runningStages.remove(stage.id)
            errors.append(ex)
         }
 
@@ -322,6 +325,7 @@ case class AdaptiveSparkPlanExec(
         Some((planChangeLogger, "AQE Post Stage Creation")))
       isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
+      cancelRunningStages()
       currentPhysicalPlan
     }
   }
@@ -739,6 +743,26 @@ case class AdaptiveSparkPlanExec(
     }
     throw e
   }
+
+  /**
+   * Cancel jobs that are still running after the plan is finished. For example, for an inner
+   * join with an empty left side and a large right side, we convert the join to a LocalRelation
+   * once the left side is materialized, while the right side job is still running.
+   */
+  private def cancelRunningStages(): Unit = {
+    if (!isSubquery && !context.runningStages.isEmpty) {
+      logInfo(s"Cancel query stages (${context.runningStages.keys().asScala.mkString(", ")}), " +
+        s"because the plan is finished.")
+      context.runningStages.values().asScala.foreach { stage =>
+        try {
+          stage.cancel()
+        } catch {
+          case NonFatal(t) =>
+            logWarning(s"Exception in cancelling query stage: ${stage.treeString}", t)
+        }
+      }
+    }
+  }
 }
 
 object AdaptiveSparkPlanExec {
@@ -795,6 +819,11 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
    */
   val stageCache: TrieMap[SparkPlan, QueryStageExec] =
     new TrieMap[SparkPlan, QueryStageExec]()
+
+  /**
+   * The map of running query stages of the entire query, including sub-queries, keyed by stage id.
+   */
+  val runningStages = new ConcurrentHashMap[Int, QueryStageExec]
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c9d10e0e6cd9..812a86a83453 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -23,7 +23,7 @@ import java.net.URI
 import org.apache.log4j.Level
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
+import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerEvent, SparkListenerJobEnd, SparkListenerJobStart}
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
@@ -2192,6 +2192,46 @@ class AdaptiveQueryExecSuite
       }
     }
   }
+
+  test("SPARK-37043: Cancel all running jobs after AQE plan finished") {
+    spark.range(1).createOrReplaceTempView("v")
+    withTempView("v") {
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+        // We expect two jobs for the two shuffle nodes in the join;
+        // the LocalTableScanExec does not need a job.
+        @volatile var firstJob = true
+        @volatile var failedJob: JobFailed = null
+        val listener = new SparkListener {
+          override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+            if (firstJob) {
+              firstJob = false
+            } else {
+              jobEnd.jobResult match {
+                case job: JobFailed => failedJob = job
+                case _ =>
+              }
+            }
+          }
+        }
+        spark.sparkContext.addSparkListener(listener)
+        try {
+          val (origin, adaptive) = runAdaptiveAndVerifyResult(
+            """
+              |SELECT * FROM emptyTestData t1 JOIN (
+              |  SELECT id, java_method('java.lang.Thread', 'sleep', 3000L) FROM v
+              |) t2 ON t1.key = t2.id
+              |""".stripMargin)
+          assert(origin.isInstanceOf[SortMergeJoinExec])
+          assert(adaptive.isInstanceOf[LocalTableScanExec])
+          spark.sparkContext.listenerBus.waitUntilEmpty(5000)
+          assert(failedJob != null)
+          assert(failedJob.exception.getMessage.contains("cancelled"))
+        } finally {
+          spark.sparkContext.removeSparkListener(listener)
+        }
+      }
+    }
+  }
 }
 
 /**
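
For context, the scenario the new cancelRunningStages() method targets can be reproduced outside the test suite. Below is a minimal standalone sketch, not part of the patch: the object name, master setting, row count, and sleep duration are arbitrary illustrative choices, and it assumes a local Spark build with AQE available.

// Repro sketch for the scenario described in cancelRunningStages():
// an inner join whose empty left side lets AQE rewrite the whole join
// into a LocalRelation while the right side's shuffle job is still running.
import org.apache.spark.sql.SparkSession

object CancelRunningStagesRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-37043 repro sketch")
      // Force shuffle-based joins so both join sides become shuffle query stages.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    // The left side is empty, so its query stage materializes almost instantly.
    val left = Seq.empty[Long].toDF("key")

    // The right side sleeps one second per row, so its shuffle job keeps
    // running long after the left side has finished.
    val right = spark.range(100)
      .selectExpr("id", "java_method('java.lang.Thread', 'sleep', 1000L) AS dummy")

    // Once AQE sees the empty left stage, it replaces the inner join with an
    // empty LocalRelation. With this patch, the still-running right-side stage
    // is cancelled; without it, the right-side job runs to completion for nothing.
    left.join(right, left("key") === right("id")).collect()

    spark.stop()
  }
}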