Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive

import java.util
import java.util.concurrent.LinkedBlockingQueue
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.JavaConverters._
import scala.collection.concurrent.TrieMap
Expand Down Expand Up @@ -282,8 +282,10 @@ case class AdaptiveSparkPlanExec(

// Start materialization of all new stages and fail fast if any stages failed eagerly
reorderedNewStages.foreach { stage =>
context.runningStages.put(stage.id, stage)
try {
stage.materialize().onComplete { res =>
context.runningStages.remove(stage.id)
if (res.isSuccess) {
events.offer(StageSuccess(stage, res.get))
} else {
Expand All @@ -294,6 +296,7 @@ case class AdaptiveSparkPlanExec(
}(AdaptiveSparkPlanExec.executionContext)
} catch {
case e: Throwable =>
context.runningStages.remove(stage.id)
cleanUpAndThrowException(Seq(e), Some(stage.id))
}
}
Expand Down Expand Up @@ -348,6 +351,7 @@ case class AdaptiveSparkPlanExec(
result = createQueryStages(currentPhysicalPlan)
}

cancelRunningStages()
// Run the final plan when there's no more unfinished stages.
currentPhysicalPlan = applyPhysicalRules(
optimizeQueryStage(result.newPlan, isFinalStage = true),
Expand Down Expand Up @@ -794,6 +798,27 @@ case class AdaptiveSparkPlanExec(
}
throw e
}

/**
 * Cancels the exchange stages that are still registered as running when the main query is
 * about to move to its final plan, then empties the running-stage registry.
 * e.g., for an inner join with an empty left side and a large right side, the join is converted
 * to `LocalRelation` once the left side is materialized, while the right side stage is still
 * running at that time.
 *
 * Does nothing when this plan is a subquery or when no stage is registered as running.
 * Non-fatal failures raised while cancelling a stage are logged as warnings and ignored.
 */
private def cancelRunningStages(): Unit = {
  val registry = context.runningStages
  val nothingToCancel = isSubquery || registry.isEmpty
  if (!nothingToCancel) {
    registry.values().asScala.iterator
      .collect { case exchange: ExchangeQueryStageExec => exchange }
      .foreach { exchange =>
        try {
          exchange.cancel()
        } catch {
          case NonFatal(cause) =>
            logWarning(s"Exception in cancelling query stage: ${exchange.treeString}", cause)
        }
      }
    // Stages of any other kind are not cancelled, but they are dropped from the registry too.
    registry.clear()
  }
}
}

object AdaptiveSparkPlanExec {
Expand Down Expand Up @@ -850,6 +875,11 @@ case class AdaptiveExecutionContext(session: SparkSession, qe: QueryExecution) {
*/
val stageCache: TrieMap[SparkPlan, ExchangeQueryStageExec] =
new TrieMap[SparkPlan, ExchangeQueryStageExec]()

/**
 * Registry of the query stages currently running for the entire query, sub-queries included.
 * Entries are keyed by an `Int` (the stage id at the visible call sites).
 */
val runningStages: ConcurrentHashMap[Int, QueryStageExec] =
  new ConcurrentHashMap[Int, QueryStageExec]()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ case class BroadcastExchangeExec(
relationFuture.cancel(true)
}
throw QueryExecutionErrors.executeBroadcastTimeoutError(timeout, Some(ex))
case (_: CancellationException | _: InterruptedException) if conf.adaptiveExecutionEnabled =>
// This happens when a broadcast job is cancelled, which is only supported in AQE.
// Throwing an exception here would escape AQE's control, so swallow it instead.
// The returned value should never be used, so just return null.
null
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.scalatest.PrivateMethodTester
import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart}
import org.apache.spark.scheduler.{JobFailed, SparkListener, SparkListenerEvent, SparkListenerJobEnd, SparkListenerJobStart}
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
Expand Down Expand Up @@ -2841,6 +2841,46 @@ class AdaptiveQueryExecSuite
}
}
}

test("SPARK-37043: Cancel all running job when main query is going to final plan") {
  withTempView("v") {
    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
      spark.range(1).createOrReplaceTempView("v")
      // The join has two shuffle nodes, so two jobs are expected;
      // `LocalTableScanExec` doesn't need a job.
      //
      // Both variables are written by the listener callback and read by the assertions
      // below, so they are volatile to make those writes visible to the test body.
      @volatile var firstJob = true
      @volatile var failedJob: JobFailed = null
      val listener = new SparkListener {
        override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
          if (firstJob) {
            // The first job to end is skipped; only later jobs are inspected for failure.
            firstJob = false
          } else {
            jobEnd.jobResult match {
              case job: JobFailed => failedJob = job
              case _ =>
            }
          }
        }
      }
      spark.sparkContext.addSparkListener(listener)
      try {
        val (origin, adaptive) = runAdaptiveAndVerifyResult(
          """
            |SELECT * FROM emptyTestData t1 JOIN (
            | SELECT id, java_method('java.lang.Thread', 'sleep', 3000L) FROM v
            |) t2 ON t1.key = t2.id
            |""".stripMargin)
        assert(origin.isInstanceOf[SortMergeJoinExec])
        assert(adaptive.isInstanceOf[LocalTableScanExec])
        // Let pending listener events drain before checking what the listener recorded.
        spark.sparkContext.listenerBus.waitUntilEmpty(5000)
        assert(failedJob != null)
        assert(failedJob.exception.getMessage.contains("cancelled"))
      } finally {
        // Always unregister so the listener cannot leak into other tests.
        spark.sparkContext.removeSparkListener(listener)
      }
    }
  }
}
}

/**
Expand Down