Changes from 7 commits
Commits
33 commits
e93cedd
tags api
xupefei Aug 20, 2024
694db1b
Merge branch 'master' of github.com:apache/spark into reverse-api-tag
xupefei Aug 20, 2024
40610a7
rename
xupefei Aug 20, 2024
a70d7d2
address comments
xupefei Aug 21, 2024
0656e25
.
xupefei Aug 21, 2024
6b6ca7f
.
xupefei Aug 21, 2024
d3cd5f5
new approach
xupefei Aug 23, 2024
0922dd2
address comments
xupefei Aug 26, 2024
2a6fcc6
return job IDs earlier
xupefei Aug 26, 2024
ef0fddf
doc
xupefei Aug 26, 2024
f2ad163
no mention of spark session in core
xupefei Aug 27, 2024
ab00685
re
xupefei Aug 27, 2024
dd10f46
fix test
xupefei Aug 27, 2024
bc9b76d
revert some changes
xupefei Aug 28, 2024
8656810
undo
xupefei Aug 28, 2024
1dfafad
wip
xupefei Aug 28, 2024
1d4d5cc
.
xupefei Aug 29, 2024
a35c4e5
Merge branch 'master' of github.com:apache/spark into reverse-api-tag
xupefei Aug 29, 2024
d1208c4
revert unnecessary changes and fix tests
xupefei Aug 29, 2024
13342cf
comment
xupefei Aug 29, 2024
3879989
oh no
xupefei Aug 29, 2024
cf6437f
remove internal tags
xupefei Aug 30, 2024
4d7da3b
Merge branch 'master' of github.com:apache/spark into reverse-api-tag
xupefei Aug 30, 2024
b3b7cbc
test
xupefei Aug 30, 2024
7c9294e
Merge branch 'master' of github.com:apache/spark into reverse-api-tag
xupefei Aug 30, 2024
7338b1d
move doc to api
xupefei Aug 30, 2024
905bf91
fix test
xupefei Sep 3, 2024
514b5e4
address mridulm's comments
xupefei Sep 10, 2024
c6fb41f
address herman's comments
xupefei Sep 10, 2024
2a0292c
address hyukjin's comment
xupefei Sep 10, 2024
2d059b3
Merge branch 'master' of github.com:apache/spark into reverse-api-tag
xupefei Sep 10, 2024
a55c47c
scalastyle
xupefei Sep 16, 2024
e66ba0a
fmt
xupefei Sep 17, 2024
@@ -741,7 +741,7 @@ class SparkSession private[sql] (
* Often, a unit of execution in an application consists of multiple Spark executions.
* Application programmers can use this method to group all those jobs together and give a group
* tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
* running running executions with this tag. For example:
* running executions with this tag. For example:
* {{{
* // In the main thread:
* spark.addTag("myjobs")
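For readers unfamiliar with the tag API this doc comment describes, here is a minimal usage sketch, assuming a `SparkSession` that exposes `addTag`/`removeTag`/`interruptTag` (as the Connect client does today); the tag name and query are illustrative only:

```scala
// Sketch only: tag executions in one thread, interrupt them from another.
import org.apache.spark.sql.SparkSession

def runTaggedWork(spark: SparkSession): Unit = {
  spark.addTag("myjobs") // executions started by this thread now carry the tag
  try {
    spark.range(0, 10000000L).repartition(10).count() // long-running work to be cancelled
  } finally {
    spark.removeTag("myjobs") // stop tagging further executions from this thread
  }
}

// In another thread (e.g. a watchdog or a UI cancel handler):
def cancelTaggedWork(spark: SparkSession): Unit = {
  // interruptTag is documented to return the IDs of the interrupted operations.
  val interrupted = spark.interruptTag("myjobs")
  println(s"Interrupted ${interrupted.size} executions")
}
```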
@@ -408,21 +408,6 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.SparkSession.addArtifacts"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.registerClassFinder"),
// public
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptAll"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptOperation"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.addTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.removeTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.getTags"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.clearTags"),
// SparkSession#Builder
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.remote"),
82 changes: 77 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.Map
import scala.collection.concurrent.{Map => ScalaConcurrentMap}
import scala.collection.immutable
import scala.collection.mutable.HashMap
import scala.concurrent.{Future, Promise}
import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
@@ -825,6 +826,11 @@ class SparkContext(config: SparkConf) extends Logging {
def getLocalProperty(key: String): String =
Option(localProperties.get).map(_.getProperty(key)).orNull

/** Set the UUID of the Spark session that starts the current job. */
def setSparkSessionUUID(uuid: String): Unit = {
setLocalProperty(SparkContext.SPARK_SESSION_UUID, uuid)
}

/** Set a human readable description of the current job. */
def setJobDescription(value: String): Unit = {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
@@ -2684,6 +2690,29 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None)
}

/**
* Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
*
* @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
* @param reason reason for cancellation
* @param shouldCancelJob Callback function to be called with each active job that has the given
* tag. If the function returns true, the job will be cancelled.
* @return A future that will be completed with the set of job IDs that were cancelled.
*
* @since 4.0.0
*/
def cancelJobsWithTag(
tag: String,
reason: String,
shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()

val cancelledJobs = Promise[Set[Int]]()
dagScheduler.cancelJobsWithTag(tag, Option(reason), Some(shouldCancelJob), Some(cancelledJobs))
cancelledJobs.future
}

/**
* Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
*
@@ -2695,7 +2724,11 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String, reason: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
dagScheduler.cancelJobsWithTag(tag, Option(reason))
dagScheduler.cancelJobsWithTag(
tag,
Option(reason),
shouldCancelJob = None,
cancelledJobs = None)
}

/**
@@ -2708,13 +2741,51 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
dagScheduler.cancelJobsWithTag(tag, None)
dagScheduler.cancelJobsWithTag(
tag,
reason = None,
shouldCancelJob = None,
cancelledJobs = None)
}

/**
* Cancel all jobs that have been scheduled or are running.
*
* @param shouldCancelJob Callback function to be called with each active job. If the function
* returns true, the job will be cancelled.
* @return A future that will be completed with the set of job IDs that were cancelled.
*/
def cancelAllJobs(shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = {
assertNotStopped()

val cancelledJobs = Promise[Set[Int]]()
dagScheduler.cancelAllJobs(Some(shouldCancelJob), Some(cancelledJobs))
cancelledJobs.future
}

/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs(): Unit = {
assertNotStopped()
dagScheduler.cancelAllJobs()
dagScheduler.cancelAllJobs(shouldCancelJob = None, cancelledJobs = None)
}

/**
* Cancel a given job if it's scheduled or running.
*
* @param jobId the job ID to cancel
* @param reason reason for cancellation
* @param shouldCancelJob Callback function to be called with the active job that has the given
* job ID, if any. If the function returns true, the job will be cancelled.
* @return A future that will be completed with the set of job IDs that were cancelled.
* @note Throws `InterruptedException` if the cancel message cannot be sent
*/
def cancelJob(
jobId: Int,
reason: String,
shouldCancelJob: ActiveJob => Boolean): Future[Set[Int]] = {
val cancelledJobs = Promise[Set[Int]]()
dagScheduler.cancelJob(jobId, Option(reason), Some(shouldCancelJob), Some(cancelledJobs))
cancelledJobs.future
}

/**
@@ -2725,7 +2796,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @note Throws `InterruptedException` if the cancel message cannot be sent
*/
def cancelJob(jobId: Int, reason: String): Unit = {
dagScheduler.cancelJob(jobId, Option(reason))
dagScheduler.cancelJob(jobId, Option(reason), shouldCancelJob = None, cancelledJobs = None)
}

/**
@@ -2735,7 +2806,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @note Throws `InterruptedException` if the cancel message cannot be sent
*/
def cancelJob(jobId: Int): Unit = {
dagScheduler.cancelJob(jobId, None)
dagScheduler.cancelJob(jobId, reason = None, shouldCancelJob = None, cancelledJobs = None)
}

/**
@@ -3084,6 +3155,7 @@ object SparkContext extends Logging {
private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
private[spark] val SPARK_JOB_TAGS = "spark.job.tags"
private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool"
private[spark] val SPARK_SESSION_UUID = "spark.sparkSession.uuid"
private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope"
private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride"

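As a rough illustration of the new callback-based overloads added to SparkContext above, here is a sketch of how an internal caller could cancel only the tagged jobs belonging to one session. Since `ActiveJob` is `private[spark]`, such code has to live under an `org.apache.spark` package (in this PR the real caller is the SQL SparkSession); the package, object, tag, and UUID below are placeholders:

```scala
package org.apache.spark.example // placeholder package; must sit under org.apache.spark

import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.ActiveJob

object SessionScopedCancellation {
  // Cancels jobs carrying `tag` that were started by the session with `sessionUuid`,
  // and returns the IDs of the jobs that were actually cancelled.
  def cancelSessionJobs(sc: SparkContext, tag: String, sessionUuid: String): Set[Int] = {
    val cancelled = sc.cancelJobsWithTag(
      tag,
      reason = s"cancelled on behalf of session $sessionUuid",
      // Evaluated by the scheduler for each active job that has the tag.
      shouldCancelJob = (job: ActiveJob) => job.getSparkSessionUUID.contains(sessionUuid))
    Await.result(cancelled, 1.minute)
  }
}
```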
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler

import java.util.Properties

import org.apache.spark.JobArtifactSet
import org.apache.spark.{JobArtifactSet, SparkContext}
import org.apache.spark.util.CallSite

/**
@@ -63,4 +63,7 @@ private[spark] class ActiveJob(
val finished = Array.fill[Boolean](numPartitions)(false)

var numFinished = 0

def getSparkSessionUUID: Option[String] =
Option(properties.getProperty(SparkContext.SPARK_SESSION_UUID))
}
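The new accessor reads the session UUID back out of the job's properties. A small sketch of how the two ends connect, with a placeholder UUID (in the real change the SQL SparkSession would supply its own identifier):

```scala
import java.util.UUID

import org.apache.spark.SparkContext

object SessionUuidTagging {
  def runUnderSession(sc: SparkContext): Unit = {
    val sessionUuid = UUID.randomUUID().toString // placeholder value
    // Stamps "spark.sparkSession.uuid" into this thread's local properties, so every
    // job submitted from this thread carries the UUID in its job properties.
    sc.setSparkSessionUUID(sessionUuid)
    sc.parallelize(1 to 100).count()
    // Scheduler-side code can later match the job back to the session via
    // activeJob.getSparkSessionUUID.contains(sessionUuid).
  }
}
```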
99 changes: 70 additions & 29 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -27,6 +27,7 @@ import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.control.NonFatal

@@ -144,7 +145,7 @@ private[spark] class DAGScheduler(
private[spark] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this)

private[scheduler] val nextJobId = new AtomicInteger(0)
private[scheduler] def numTotalJobs: Int = nextJobId.get()
private[spark] def numTotalJobs: Int = nextJobId.get()
private val nextStageId = new AtomicInteger(0)

private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]]
@@ -167,7 +168,7 @@ private[spark] class DAGScheduler(
// Stages that must be resubmitted due to fetch failures
private[scheduler] val failedStages = new HashSet[Stage]

private[scheduler] val activeJobs = new HashSet[ActiveJob]
private[spark] val activeJobs = new HashSet[ActiveJob]

// Job groups that are cancelled with `cancelFutureJobs` as true, with at most
// `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in
@@ -1099,9 +1100,13 @@ private[spark] class DAGScheduler(
/**
* Cancel a job that is running or waiting in the queue.
*/
def cancelJob(jobId: Int, reason: Option[String]): Unit = {
def cancelJob(
jobId: Int,
reason: Option[String],
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}")
eventProcessLoop.post(JobCancelled(jobId, reason))
eventProcessLoop.post(JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs))
}

/**
@@ -1116,26 +1121,38 @@

/**
* Cancel all jobs with a given tag.
*
* @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
* @param reason reason for cancellation
* @param shouldCancelJob Callback function to be called with each active job that has the given
* tag. If the function returns true, the job will be cancelled.
* @param cancelledJobs If set, a promise that will be completed with the IDs of the jobs that
* were actually cancelled.
*/
def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = {
def cancelJobsWithTag(
tag: String,
reason: Option[String],
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
SparkContext.throwIfInvalidTag(tag)
logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}")
eventProcessLoop.post(JobTagCancelled(tag, reason))
eventProcessLoop.post(JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs))
}

/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs(): Unit = {
eventProcessLoop.post(AllJobsCancelled)
def cancelAllJobs(
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
eventProcessLoop.post(AllJobsCancelled(shouldCancelJob, cancelledJobs))
}

private[scheduler] def doCancelAllJobs(): Unit = {
def doCancelAllJobs(
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
// Cancel all running jobs.
runningStages.map(_.firstJobId).foreach(handleJobCancellation(_,
Option("as part of cancellation of all jobs")))
activeJobs.clear() // These should already be empty by this point,
jobIdToActiveJob.clear() // but just in case we lost track of some jobs...
val cancelled = runningStages.map(_.firstJobId)
.filter(doJobCancellation(_, Option("as part of cancellation of all jobs"), shouldCancelJob))
cancelledJobs.foreach(_.success(cancelled.toSet))
}

/**
@@ -1231,10 +1248,14 @@ }
}
val jobIds = activeInGroup.map(_.jobId)
val updatedReason = reason.getOrElse("part of cancelled job group %s".format(groupId))
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
jobIds.foreach(doJobCancellation(_, Option(updatedReason), shouldCancelJob = None))
}

private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = {
private[scheduler] def handleJobTagCancelled(
tag: String,
reason: Option[String],
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
// Cancel all jobs that have this tag.
// First find all active jobs with this tag, and then kill stages for them.
val jobIds = activeJobs.filter { activeJob =>
Expand All @@ -1244,7 +1265,8 @@ private[spark] class DAGScheduler(
}
}.map(_.jobId)
val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag))
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
val cancelled = jobIds.filter(doJobCancellation(_, Option(updatedReason), shouldCancelJob))
cancelledJobs.foreach(_.success(cancelled.toSet))
}

private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
@@ -2801,21 +2823,40 @@ private[spark] class DAGScheduler(
case None =>
s"because Stage $stageId was cancelled"
}
handleJobCancellation(jobId, Option(reasonStr))
doJobCancellation(jobId, Option(reasonStr), shouldCancelJob = None)
}
case None =>
logInfo(log"No active jobs to kill for Stage ${MDC(STAGE_ID, stageId)}")
}
}

private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = {
private[scheduler] def handleJobCancellation(
jobId: Int,
reason: Option[String],
shouldCancelJob: Option[ActiveJob => Boolean],
cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
val cancelled = Set(jobId).filter(doJobCancellation(_, reason, shouldCancelJob))
cancelledJobs.foreach(_.success(cancelled))
}

private def doJobCancellation(
jobId: Int,
reason: Option[String],
shouldCancelJob: Option[ActiveJob => Boolean]): Boolean = {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
false
} else {
failJobAndIndependentStages(
job = jobIdToActiveJob(jobId),
error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null)
)
val activeJob = jobIdToActiveJob(jobId)
if (shouldCancelJob.forall(_(activeJob))) {
failJobAndIndependentStages(
job = activeJob,
error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null)
)
true
} else {
false
}
}
}

@@ -3107,17 +3148,17 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case StageCancelled(stageId, reason) =>
dagScheduler.handleStageCancellation(stageId, reason)

case JobCancelled(jobId, reason) =>
dagScheduler.handleJobCancellation(jobId, reason)
case JobCancelled(jobId, reason, shouldCancelJob, cancelledJobs) =>
dagScheduler.handleJobCancellation(jobId, reason, shouldCancelJob, cancelledJobs)

case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)

case JobTagCancelled(tag, reason) =>
dagScheduler.handleJobTagCancelled(tag, reason)
case JobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs) =>
dagScheduler.handleJobTagCancelled(tag, reason, shouldCancelJob, cancelledJobs)

case AllJobsCancelled =>
dagScheduler.doCancelAllJobs()
case AllJobsCancelled(shouldCancelJob, cancelledJobs) =>
dagScheduler.doCancelAllJobs(shouldCancelJob, cancelledJobs)

case ExecutorAdded(execId, host) =>
dagScheduler.handleExecutorAdded(execId, host)
@@ -3173,7 +3214,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
override def onError(e: Throwable): Unit = {
logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
try {
dagScheduler.doCancelAllJobs()
dagScheduler.doCancelAllJobs(shouldCancelJob = None, cancelledJobs = None)
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
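The scheduler changes above thread an optional `Promise[Set[Int]]` from the SparkContext entry points through the event loop, which completes it once the handler knows which jobs it actually cancelled. A stripped-down, self-contained sketch of that handshake pattern (plain Scala, not the Spark classes themselves):

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

object CancellationHandshakeSketch {
  // Stand-in for the event-loop handler: filters candidates through an optional
  // predicate and reports the outcome through an optional promise.
  def handleCancellation(
      candidateJobIds: Set[Int],
      shouldCancel: Option[Int => Boolean],
      cancelledJobs: Option[Promise[Set[Int]]]): Unit = {
    // A job is cancelled when no predicate is given or the predicate accepts it.
    val cancelled = candidateJobIds.filter(id => shouldCancel.forall(_(id)))
    cancelledJobs.foreach(_.success(cancelled))
  }

  def main(args: Array[String]): Unit = {
    val promise = Promise[Set[Int]]()
    // The caller posts the request together with the promise and keeps the future.
    handleCancellation(Set(1, 2, 3), Some((id: Int) => id % 2 == 1), Some(promise))
    println(Await.result(promise.future, 10.seconds)) // prints Set(1, 3)
  }
}
```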