33 commits
e93cedd - tags api (xupefei, Aug 20, 2024)
694db1b - Merge branch 'master' of github.com:apache/spark into reverse-api-tag (xupefei, Aug 20, 2024)
40610a7 - rename (xupefei, Aug 20, 2024)
a70d7d2 - address comments (xupefei, Aug 21, 2024)
0656e25 - . (xupefei, Aug 21, 2024)
6b6ca7f - . (xupefei, Aug 21, 2024)
d3cd5f5 - new approach (xupefei, Aug 23, 2024)
0922dd2 - address comments (xupefei, Aug 26, 2024)
2a6fcc6 - return job IDs earlier (xupefei, Aug 26, 2024)
ef0fddf - doc (xupefei, Aug 26, 2024)
f2ad163 - no mention of spark session in core (xupefei, Aug 27, 2024)
ab00685 - re (xupefei, Aug 27, 2024)
dd10f46 - fix test (xupefei, Aug 27, 2024)
bc9b76d - revert some changes (xupefei, Aug 28, 2024)
8656810 - undo (xupefei, Aug 28, 2024)
1dfafad - wip (xupefei, Aug 28, 2024)
1d4d5cc - . (xupefei, Aug 29, 2024)
a35c4e5 - Merge branch 'master' of github.com:apache/spark into reverse-api-tag (xupefei, Aug 29, 2024)
d1208c4 - revert unnessesary changes and fix tests (xupefei, Aug 29, 2024)
13342cf - comment (xupefei, Aug 29, 2024)
3879989 - oh no (xupefei, Aug 29, 2024)
cf6437f - remove internal tags (xupefei, Aug 30, 2024)
4d7da3b - Merge branch 'master' of github.com:apache/spark into reverse-api-tag (xupefei, Aug 30, 2024)
b3b7cbc - test (xupefei, Aug 30, 2024)
7c9294e - Merge branch 'master' of github.com:apache/spark into reverse-api-tag (xupefei, Aug 30, 2024)
7338b1d - move doc to api (xupefei, Aug 30, 2024)
905bf91 - fix test (xupefei, Sep 3, 2024)
514b5e4 - address mridulm's comments (xupefei, Sep 10, 2024)
c6fb41f - address herman's comments (xupefei, Sep 10, 2024)
2a0292c - address hyukjin's comment (xupefei, Sep 10, 2024)
2d059b3 - Merge branch 'master' of github.com:apache/spark into reverse-api-tag (xupefei, Sep 10, 2024)
a55c47c - scalastyle (xupefei, Sep 16, 2024)
e66ba0a - fmt (xupefei, Sep 17, 2024)
@@ -428,7 +428,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
def interruptAll(): Seq[String] = {
override def interruptAll(): Seq[String] = {
client.interruptAll().getInterruptedIdsList.asScala.toSeq
}

@@ -441,7 +441,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
def interruptTag(tag: String): Seq[String] = {
override def interruptTag(tag: String): Seq[String] = {
client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq
}

@@ -454,7 +454,7 @@ class SparkSession private[sql] (
*
* @since 3.5.0
*/
def interruptOperation(operationId: String): Seq[String] = {
override def interruptOperation(operationId: String): Seq[String] = {
client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq
}

@@ -485,65 +485,17 @@ class SparkSession private[sql] (
SparkSession.onSessionClose(this)
}

/**
* Add a tag to be assigned to all the operations started by this thread in this session.
*
* Often, a unit of execution in an application consists of multiple Spark executions.
* Application programmers can use this method to group all those jobs together and give a group
* tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all
 * running executions with this tag. For example:
* {{{
* // In the main thread:
* spark.addTag("myjobs")
* spark.range(10).map(i => { Thread.sleep(10); i }).collect()
*
* // In a separate thread:
* spark.interruptTag("myjobs")
* }}}
*
* There may be multiple tags present at the same time, so different parts of application may
* use different tags to perform cancellation at different levels of granularity.
*
* @param tag
* The tag to be added. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def addTag(tag: String): Unit = {
client.addTag(tag)
}
/** @inheritdoc */
override def addTag(tag: String): Unit = client.addTag(tag)

/**
* Remove a tag previously added to be assigned to all the operations started by this thread in
* this session. Noop if such a tag was not added earlier.
*
* @param tag
* The tag to be removed. Cannot contain ',' (comma) character or be an empty string.
*
* @since 3.5.0
*/
def removeTag(tag: String): Unit = {
client.removeTag(tag)
}
/** @inheritdoc */
override def removeTag(tag: String): Unit = client.removeTag(tag)

/**
* Get the tags that are currently set to be assigned to all the operations started by this
* thread.
*
* @since 3.5.0
*/
def getTags(): Set[String] = {
client.getTags()
}
/** @inheritdoc */
override def getTags(): Set[String] = client.getTags()

/**
* Clear the current thread's operation tags.
*
* @since 3.5.0
*/
def clearTags(): Unit = {
client.clearTags()
}
/** @inheritdoc */
override def clearTags(): Unit = client.clearTags()

/**
* We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side.
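The tag and interrupt methods in this file now carry `override`, so the usage pattern from the removed Scaladoc still applies through the inherited API. A minimal sketch, assuming an active Spark Connect `SparkSession` named `spark`:

```scala
// Sketch only: `spark` is an assumed, already-connected Spark Connect SparkSession.
// In the main thread: tag all operations started by this thread, then run a long job.
spark.addTag("myjobs")
spark.range(10).map(i => { Thread.sleep(10); i }).collect()

// In a separate thread: interrupt every running operation that carries the tag.
// interruptTag returns the IDs of the interrupted operations.
val interrupted: Seq[String] = spark.interruptTag("myjobs")
```

`interruptAll()` and `interruptOperation(operationId)` follow the same shape and likewise return the interrupted operation IDs.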
@@ -365,21 +365,6 @@ object CheckConnectJvmClientCompatibility {
// Experimental
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.registerClassFinder"),
// public
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptAll"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.interruptOperation"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.addTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.removeTag"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.getTags"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession.clearTags"),
// SparkSession#Builder
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.SparkSession#Builder.remote"),
56 changes: 48 additions & 8 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -27,6 +27,7 @@ import scala.collection.Map
import scala.collection.concurrent.{Map => ScalaConcurrentMap}
import scala.collection.immutable
import scala.collection.mutable.HashMap
import scala.concurrent.{Future, Promise}
import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
import scala.util.control.NonFatal
@@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
def addJobTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
def addJobTag(tag: String): Unit = addJobTags(Set(tag))

/**
* Add multiple tags to be assigned to all the jobs started by this thread.
* See [[addJobTag]] for more details.
*
* @param tags The tags to be added. Cannot contain ',' (comma) character.
*
* @since 4.0.0
*/
def addJobTags(tags: Set[String]): Unit = {
tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags)
}
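As a rough usage sketch of the new plural tag methods alongside the existing single-tag and cancellation APIs (assuming an already-created `SparkContext` named `sc`; the tag names are illustrative):

```scala
// Sketch only: `sc` is an assumed, already-created SparkContext.
sc.addJobTags(Set("etl", "nightly"))  // jobs started by this thread now carry both tags
sc.removeJobTags(Set("nightly"))      // jobs started afterwards keep only the "etl" tag
sc.cancelJobsWithTag("etl")           // cancels active jobs that carry the "etl" tag
sc.clearJobTags()                     // drops all tags for this thread
```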

@@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging {
*
* @since 3.5.0
*/
def removeJobTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
def removeJobTag(tag: String): Unit = removeJobTags(Set(tag))

/**
* Remove multiple tags to be assigned to all the jobs started by this thread.
* See [[removeJobTag]] for more details.
*
* @param tags The tags to be removed. Cannot contain ',' (comma) character.
*
* @since 4.0.0
*/
def removeJobTags(tags: Set[String]): Unit = {
tags.foreach(SparkContext.throwIfInvalidTag)
val existingTags = getJobTags()
val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP)
if (newTags.isEmpty) {
clearJobTags()
} else {
@@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging {
dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None)
}

/**
* Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
*
* @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
* @param reason reason for cancellation.
* @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and
* tags.
*/
private[spark] def cancelJobsWithTagWithFuture(
tag: String,
reason: String): Future[Seq[ActiveJob]] = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()

val cancelledJobs = Promise[Seq[ActiveJob]]()
dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs))
cancelledJobs.future
}

/**
* Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`.
*
@@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String, reason: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
dagScheduler.cancelJobsWithTag(tag, Option(reason))
dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None)
}

/**
@@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging {
def cancelJobsWithTag(tag: String): Unit = {
SparkContext.throwIfInvalidTag(tag)
assertNotStopped()
dagScheduler.cancelJobsWithTag(tag, None)
dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None)
}

/** Cancel all jobs that have been scheduled or are running. */
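The `private[spark]` future-returning variant added above is what lets callers inside Spark (for example, the Connect server's interrupt handling) find out which jobs were actually cancelled. A hedged sketch of consuming the returned future:

```scala
// Sketch only: cancelJobsWithTagWithFuture is private[spark], so code like this
// would have to live inside Spark itself rather than in a user application.
import scala.concurrent.ExecutionContext.Implicits.global

val cancelled = sc.cancelJobsWithTagWithFuture("myjobs", "interrupted by user")
cancelled.foreach { jobs =>
  // ActiveJob exposes the numeric job ID (and the properties holding its tags).
  val jobIds = jobs.map(_.jobId)
  println(s"Cancelled ${jobIds.size} job(s): ${jobIds.mkString(", ")}")
}
```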
33 changes: 23 additions & 10 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -27,6 +27,7 @@ import scala.annotation.tailrec
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.control.NonFatal

@@ -1116,11 +1117,18 @@

/**
* Cancel all jobs with a given tag.
*
* @param tag The tag to be cancelled. Cannot contain ',' (comma) character.
* @param reason reason for cancellation.
* @param cancelledJobs a promise to be completed with operation IDs being cancelled.
*/
def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = {
def cancelJobsWithTag(
tag: String,
reason: Option[String],
cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
SparkContext.throwIfInvalidTag(tag)
logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}")
eventProcessLoop.post(JobTagCancelled(tag, reason))
eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs))
}

/**
@@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler(
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
}

private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = {
// Cancel all jobs belonging that have this tag.
private[scheduler] def handleJobTagCancelled(
tag: String,
reason: Option[String],
cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = {
// Cancel all jobs that have all provided tags.
// First finds all active jobs with this group id, and then kill stages for them.
val jobIds = activeJobs.filter { activeJob =>
val jobsToBeCancelled = activeJobs.filter { activeJob =>
Option(activeJob.properties).exists { properties =>
Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("")
.split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag)
}
}.map(_.jobId)
val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag))
jobIds.foreach(handleJobCancellation(_, Option(updatedReason)))
}
val updatedReason =
reason.getOrElse("part of cancelled job tags %s".format(tag))
jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason)))
cancelledJobs.map(_.success(jobsToBeCancelled.toSeq))
}

private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = {
@@ -3113,8 +3126,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)

case JobTagCancelled(tag, reason) =>
dagScheduler.handleJobTagCancelled(tag, reason)
case JobTagCancelled(tag, reason, cancelledJobs) =>
dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs)

case AllJobsCancelled =>
dagScheduler.doCancelAllJobs()
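The matching step in `handleJobTagCancelled` rebuilds each job's tag set from a comma-separated local property. A standalone sketch of that predicate; the property key `"spark.job.tags"` and the `","` separator are assumptions standing in for the `SparkContext.SPARK_JOB_TAGS` and `SPARK_JOB_TAGS_SEP` constants:

```scala
import java.util.Properties

// Illustrative predicate mirroring the filter above; key and separator are assumed.
def hasTag(properties: Properties, tag: String): Boolean = {
  Option(properties)
    .flatMap(p => Option(p.getProperty("spark.job.tags")))
    .getOrElse("")
    .split(",")
    .filter(_.nonEmpty)
    .toSet
    .contains(tag)
}
```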
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler

import java.util.Properties

import scala.concurrent.Promise

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{AccumulatorV2, CallSite}
@@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled(

private[scheduler] case class JobTagCancelled(
tagName: String,
reason: Option[String]) extends DAGSchedulerEvent
reason: Option[String],
cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent

private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent

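The `Option[Promise[...]]` field turns the fire-and-forget cancellation event into an ask-style request: the caller creates a promise, ships it inside the event, and the scheduler's event loop completes it once the matching jobs are known. A minimal, self-contained sketch of that pattern, with plain strings standing in for `ActiveJob`s:

```scala
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration._

// Toy event mirroring JobTagCancelled: a tag plus an optional reply channel.
case class TagCancelled(tag: String, reply: Option[Promise[Seq[String]]])

// "Scheduler" side: find matching jobs and complete the promise if one was supplied.
def handle(event: TagCancelled, activeJobs: Seq[String]): Unit = {
  val matched = activeJobs.filter(_.startsWith(event.tag))
  event.reply.foreach(_.success(matched))
}

// Caller side: attach a promise and wait on its future for the cancelled jobs.
val reply = Promise[Seq[String]]()
handle(TagCancelled("etl", Some(reply)), Seq("etl-backfill", "ml-train"))
val cancelled = Await.result(reply.future, 1.second)  // Seq("etl-backfill")
```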