23 changes: 21 additions & 2 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -111,7 +111,7 @@ trait FutureAction[T] extends Future[T] {
*/
@DeveloperApi
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
extends FutureAction[T] with SupportForceFinish {

@volatile private var _cancelled: Boolean = false

@@ -120,6 +120,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
jobWaiter.cancel()
}

override def forceFinish(): Unit = {
jobWaiter.forceFinish()
}

override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
jobWaiter.completionFuture.ready(atMost)
this
@@ -172,6 +176,13 @@ trait JobSubmitter {
resultFunc: => R): FutureAction[R]
}

trait SupportForceFinish {
/**
   * Forcefully finishes the execution of this action.
*/
def forceFinish(): Unit
}


/**
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
@@ -180,7 +191,7 @@ trait JobSubmitter {
*/
@DeveloperApi
class ComplexFutureAction[T](run : JobSubmitter => Future[T])
extends FutureAction[T] { self =>
extends FutureAction[T] with SupportForceFinish { self =>

@volatile private var _cancelled = false

@@ -195,6 +206,14 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T])
subActions.foreach(_.cancel())
}

override def forceFinish(): Unit = {
subActions.foreach {
case s: SupportForceFinish =>
s.forceFinish()
case _ =>
}
}

private def jobSubmitter = new JobSubmitter {
def submitJob[T, U, R](
rdd: RDD[T],
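A minimal usage sketch for the new trait, assuming an RDD named rdd is already in scope (the variable names and the fallback to cancel() are illustrative, not part of this change):

import org.apache.spark.{FutureAction, SupportForceFinish}

// countAsync() returns a SimpleFutureAction under the hood; takeAsync(n) returns a ComplexFutureAction.
val action: FutureAction[Long] = rdd.countAsync()

action match {
  case f: SupportForceFinish => f.forceFinish() // mark the underlying job as finished without waiting for it
  case _                     => action.cancel() // fall back for actions that cannot be force finished
}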
@@ -56,6 +56,10 @@ private[spark] class ApproximateActionListener[T, U, R](
}
}

override def forceFinish(result: Option[String]): Unit = {
finishedTasks = totalTasks
}

override def jobFailed(exception: Exception): Unit = {
synchronized {
failure = Some(exception)
@@ -73,7 +77,7 @@ private[spark] class ApproximateActionListener[T, U, R](
val time = System.currentTimeMillis()
if (failure.isDefined) {
throw failure.get
} else if (finishedTasks == totalTasks) {
} else if (finishedTasks >= totalTasks) {
return new PartialResult(evaluator.currentResult(), true)
} else if (time >= finishTime) {
resultObject = Some(new PartialResult(evaluator.currentResult(), false))
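A toy illustration of why the completion check presumably moves from == to >=: forceFinish jumps the counter straight to the total, and a straggler taskSucceeded arriving afterwards can still push it past the total, which an equality check would then miss forever (plain Scala, not Spark code):

var finishedTasks = 0
val totalTasks = 4

finishedTasks = totalTasks            // forceFinish(...) short-circuits the listener
finishedTasks += 1                    // a late taskSucceeded(...) still arrives from the scheduler

assert(finishedTasks >= totalTasks)   // holds; the previous `== totalTasks` check would not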
31 changes: 29 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1118,6 +1118,13 @@ private[spark] class DAGScheduler(
eventProcessLoop.post(StageCancelled(stageId, reason))
}

/**
   * Force finish a job that is running or waiting in the queue.
   */
  def forceFinishJob(jobId: Int, reason: Option[String]): Unit = {
eventProcessLoop.post(ForceFinishJob(jobId, reason))
}

/**
   * Receives notification that the shuffle push for a given shuffle from one map
   * task has completed.
@@ -2721,6 +2728,20 @@ private[spark] class DAGScheduler(
}
}

private[scheduler] def handleForceFinishJob(jobId: Int, reason: Option[String]): Unit = {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
val job = jobIdToActiveJob(jobId)
if (cancelRunningIndependentStages(job, "Unnecessary Stage", false)) {
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(SparkListenerJobForceFinish(job.jobId, clock.getTimeMillis()))
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
job.listener.forceFinish(reason)
}
}
}

/**
* Marks a stage as finished and removes it from the list of running stages.
*/
@@ -2798,7 +2819,10 @@ private[spark] class DAGScheduler(
}

/** Cancel all independent, running stages that are only used by this job. */
private def cancelRunningIndependentStages(job: ActiveJob, reason: String): Boolean = {
private def cancelRunningIndependentStages(
job: ActiveJob,
reason: String,
byError: Boolean = true): Boolean = {
var ableToCancelStages = true
val stages = jobIdToStageIds(job.jobId)
if (stages.isEmpty) {
@@ -2819,7 +2843,7 @@ private[spark] class DAGScheduler(
if (runningStages.contains(stage)) {
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptTaskThread(job), reason)
markStageAsFinished(stage, Some(reason))
markStageAsFinished(stage, if (byError) Some(reason) else None)
} catch {
case e: UnsupportedOperationException =>
logWarning(s"Could not cancel tasks for stage $stageId", e)
@@ -3000,6 +3024,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case JobCancelled(jobId, reason) =>
dagScheduler.handleJobCancellation(jobId, reason)

case ForceFinishJob(jobId, reason) =>
dagScheduler.handleForceFinishJob(jobId, reason)

case JobGroupCancelled(groupId) =>
dagScheduler.handleJobGroupCancelled(groupId)

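A minimal sketch of the control flow this adds, with names taken from the diff; the dagScheduler reference and jobId are assumed to be in scope:

// JobWaiter.forceFinish() is the entry point. It calls:
dagScheduler.forceFinishJob(jobId, Some("stage no longer needed after re-optimization"))
// which posts ForceFinishJob(jobId, reason) to the event loop. handleForceFinishJob then
// cancels the job's independent running stages with byError = false (so they are not marked
// as failed), cleans up the job state, posts SparkListenerJobForceFinish followed by
// SparkListenerJobEnd(..., JobSucceeded), and finally calls job.listener.forceFinish(reason),
// which completes the JobWaiter's promise successfully.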
@@ -63,6 +63,11 @@ private[scheduler] case class JobCancelled(
reason: Option[String])
extends DAGSchedulerEvent

private[scheduler] case class ForceFinishJob(
jobId: Int,
reason: Option[String])
extends DAGSchedulerEvent

private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent

private[scheduler] case class JobTagCancelled(tagName: String) extends DAGSchedulerEvent
@@ -24,5 +24,6 @@ package org.apache.spark.scheduler
*/
private[spark] trait JobListener {
def taskSucceeded(index: Int, result: Any): Unit
def forceFinish(result: Option[String]): Unit
def jobFailed(exception: Exception): Unit
}
@@ -63,6 +63,15 @@ private[spark] class JobWaiter[T](
}
}

override def forceFinish(result: Option[String]): Unit = {
jobPromise.success(())
}

def forceFinish(): Unit = {
    dagScheduler.forceFinishJob(jobId, Some("Unnecessary stage"))
}


override def jobFailed(exception: Exception): Unit = {
if (!jobPromise.tryFailure(exception)) {
logWarning("Ignore failure", exception)
@@ -103,6 +103,10 @@ case class SparkListenerJobEnd(
jobResult: JobResult)
extends SparkListenerEvent

@DeveloperApi
case class SparkListenerJobForceFinish(jobId: Int, time: Long)
extends SparkListenerEvent

@DeveloperApi
case class SparkListenerEnvironmentUpdate(
environmentDetails: Map[String, collection.Seq[(String, String)]])
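Because SparkListenerJobForceFinish gets no dedicated callback on SparkListener, an external listener would presumably observe it through onOtherEvent; a minimal sketch (the class name is illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobForceFinish}

class ForceFinishLogger extends SparkListener {
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case SparkListenerJobForceFinish(jobId, time) =>
      println(s"Job $jobId was force finished at $time") // react to the new event
    case _ => // ignore everything else
  }
}

// Registered like any other listener: sc.addSparkListener(new ForceFinishLogger)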
@@ -315,6 +315,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
var failure: Exception = _
val jobListener = new JobListener() {
override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
override def forceFinish(result: Option[String]): Unit = {}
override def jobFailed(exception: Exception) = { failure = exception }
}

@@ -323,6 +324,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
val results = new HashMap[Int, Any]
var failure: Exception = null
override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result)
override def forceFinish(result: Option[String]): Unit = {}
override def jobFailed(exception: Exception): Unit = { failure = exception }
}

@@ -519,6 +521,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
runEvent(JobCancelled(jobId, None))
}

/** Sends ForceFinishJob to the DAG scheduler. */
private def forceFinishJob(jobId: Int): Unit = {
runEvent(ForceFinishJob(jobId, None))
}

/** Make some tasks in task set success and check results. */
private def completeAndCheckAnswer(
taskSet: TaskSet,
@@ -715,6 +722,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
var failureReason: Option[Exception] = None
val fakeListener = new JobListener() {
override def taskSucceeded(partition: Int, value: Any): Unit = numResults += 1
override def forceFinish(result: Option[String]): Unit = {}
override def jobFailed(exception: Exception): Unit = {
failureReason = Some(exception)
}
@@ -845,6 +853,15 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
assertDataStructuresEmpty()
}

test("trivial job force finish") {
val rdd = new MyRDD(sc, 1, Nil)
val jobId = submit(rdd, Array(0))
forceFinishJob(jobId)
assert(sparkListener.failedStages === Seq())
assert(sparkListener.successfulStages === Set(0))
assertDataStructuresEmpty()
}

test("job cancellation no-kill backend") {
// make sure that the DAGScheduler doesn't crash when the TaskScheduler
// doesn't implement killTask()
@@ -1929,6 +1946,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
class FailureRecordingJobListener() extends JobListener {
var failureMessage: String = _
override def taskSucceeded(index: Int, result: Any): Unit = {}
override def forceFinish(result: Option[String]): Unit = {}
override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage }
}
val listener1 = new FailureRecordingJobListener()
@@ -49,10 +49,10 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
// - positive value means an estimated row count which can be over-estimated
// - none means the plan has not materialized or the plan can not be estimated
private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
case LogicalQueryStage(_, stage: QueryStageExec) if stage.isMaterialized =>
case LogicalQueryStage(_, _, stage: QueryStageExec) if stage.isMaterialized =>
stage.getRuntimeStatistics.rowCount

case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty &&
case LogicalQueryStage(_, _, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty &&
agg.child.isInstanceOf[QueryStageExec] =>
val stage = agg.child.asInstanceOf[QueryStageExec]
if (stage.isMaterialized) {
@@ -65,7 +65,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
}

private def isRelationWithAllNullKeys(plan: LogicalPlan): Boolean = plan match {
case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.isMaterialized =>
case LogicalQueryStage(_, _, stage: BroadcastQueryStageExec) if stage.isMaterialized =>
stage.broadcast.relationFuture.get().value == HashedRelationWithAllNullKeys
case _ => false
}
@@ -76,7 +76,7 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
}

override protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
case LogicalQueryStage(_, ShuffleQueryStageExec(_, shuffle: ShuffleExchangeLike, _))
case LogicalQueryStage(_, _, ShuffleQueryStageExec(_, shuffle: ShuffleExchangeLike, _))
if shuffle.shuffleOrigin == REPARTITION_BY_COL ||
shuffle.shuffleOrigin == REPARTITION_BY_NUM => true
case _ => false
@@ -247,6 +247,19 @@ case class AdaptiveSparkPlanExec(

def finalPhysicalPlan: SparkPlan = withFinalPlanUpdate(identity)

private def cancelUselessUnfinishedStage(
newLogicalPlan: LogicalPlan,
stagesToReplace: Seq[QueryStageExec]): Set[Int] = {
var uselessStageMap = stagesToReplace.map(s => s.id -> s).toMap
newLogicalPlan.foreachUp {
case stage: LogicalQueryStage if uselessStageMap.contains(stage.stageId) =>
uselessStageMap = uselessStageMap - stage.stageId
case _ =>
}
uselessStageMap.values.foreach(_.forceFinish())
uselessStageMap.keySet
}

private def getFinalPhysicalPlan(): SparkPlan = lock.synchronized {
if (isFinalPlan) return currentPhysicalPlan

@@ -262,6 +275,7 @@ case class AdaptiveSparkPlanExec(
val events = new LinkedBlockingQueue[StageMaterializationEvent]()
val errors = new mutable.ArrayBuffer[Throwable]()
var stagesToReplace = Seq.empty[QueryStageExec]
val uselessStagesId = mutable.Set.empty[Int]
while (!result.allChildStagesMaterialized) {
currentPhysicalPlan = result.newPlan
if (result.newStages.nonEmpty) {
@@ -309,7 +323,7 @@ case class AdaptiveSparkPlanExec(
case StageSuccess(stage, res) =>
stage.resultOption.set(Some(res))
case StageFailure(stage, ex) =>
errors.append(ex)
if (!uselessStagesId.contains(stage.id)) errors.append(ex)
}

// In case of errors, we cancel all running stages and throw exception.
@@ -341,6 +355,7 @@ case class AdaptiveSparkPlanExec(
cleanUpTempTags(newPhysicalPlan)
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
uselessStagesId ++= cancelUselessUnfinishedStage(newLogicalPlan, stagesToReplace)
stagesToReplace = Seq.empty[QueryStageExec]
}
}
@@ -678,7 +693,7 @@ case class AdaptiveSparkPlanExec(
// can be overwritten through re-planning processes.
setTempTagRecursive(physicalNode.get, logicalNode)
// Replace the corresponding logical node with LogicalQueryStage
val newLogicalNode = LogicalQueryStage(logicalNode, physicalNode.get)
val newLogicalNode = LogicalQueryStage(stage.id, logicalNode, physicalNode.get)
val newLogicalPlan = logicalPlan.transformDown {
case p if p.eq(logicalNode) => newLogicalNode
}
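A toy sketch of the bookkeeping behind cancelUselessUnfinishedStage, with plain collections standing in for real plans and query stages: any stage queued for replacement whose id no longer appears in the re-optimized logical plan gets force finished:

val stagesToReplace   = Map(1 -> "shuffle stage", 2 -> "broadcast stage", 3 -> "shuffle stage")
val idsStillInNewPlan = Set(1, 3) // ids collected while walking the new logical plan

val useless = stagesToReplace.filterNot { case (id, _) => idsStillInNewPlan(id) }
useless.foreach { case (id, desc) => println(s"force finishing stage $id ($desc)") } // prints stage 2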
@@ -57,14 +57,14 @@ object DynamicJoinSelection extends Rule[LogicalPlan] with JoinSelectionHelper {
isLeft: Boolean): Option[JoinStrategyHint] = {
val plan = if (isLeft) join.left else join.right
plan match {
case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.isMaterialized
case LogicalQueryStage(_, _, stage: ShuffleQueryStageExec) if stage.isMaterialized
&& stage.mapStats.isDefined =>

val manyEmptyInPlan = hasManyEmptyPartitions(stage.mapStats.get)
val canBroadcastPlan = (isLeft && canBuildBroadcastLeft(join.joinType)) ||
(!isLeft && canBuildBroadcastRight(join.joinType))
val manyEmptyInOther = (if (isLeft) join.right else join.left) match {
case LogicalQueryStage(_, stage: ShuffleQueryStageExec) if stage.isMaterialized
case LogicalQueryStage(_, _, stage: ShuffleQueryStageExec) if stage.isMaterialized
&& stage.mapStats.isDefined => hasManyEmptyPartitions(stage.mapStats.get)
case _ => false
}
@@ -35,6 +35,7 @@ import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
// TODO we can potentially include only [[QueryStageExec]] in this class if we make the aggregation
// planning aware of partitioning.
case class LogicalQueryStage(
stageId: Int,
logicalPlan: LogicalPlan,
physicalPlan: SparkPlan) extends LeafNode {

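The new stageId field is what lets AQE map a logical node back to its physical query stage after re-planning; a minimal sketch of collecting the ids still referenced by a re-optimized plan, assuming newLogicalPlan is in scope (this mirrors the loop in cancelUselessUnfinishedStage above):

import scala.collection.mutable
import org.apache.spark.sql.execution.adaptive.LogicalQueryStage

val stillReferenced = mutable.Set.empty[Int]
newLogicalPlan.foreachUp {
  case LogicalQueryStage(id, _, _) => stillReferenced += id // this stage is still needed
  case _ =>                                                 // other nodes carry no stage id
}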
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNes
object LogicalQueryStageStrategy extends Strategy {

private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match {
case LogicalQueryStage(_, _: BroadcastQueryStageExec) => true
case LogicalQueryStage(_, _, _: BroadcastQueryStageExec) => true
case _ => false
}
