diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 209ec88618c4..d17614a98755 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -428,7 +428,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptAll(): Seq[String] = { + override def interruptAll(): Seq[String] = { client.interruptAll().getInterruptedIdsList.asScala.toSeq } @@ -441,7 +441,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptTag(tag: String): Seq[String] = { + override def interruptTag(tag: String): Seq[String] = { client.interruptTag(tag).getInterruptedIdsList.asScala.toSeq } @@ -454,7 +454,7 @@ class SparkSession private[sql] ( * * @since 3.5.0 */ - def interruptOperation(operationId: String): Seq[String] = { + override def interruptOperation(operationId: String): Seq[String] = { client.interruptOperation(operationId).getInterruptedIdsList.asScala.toSeq } @@ -485,65 +485,17 @@ class SparkSession private[sql] ( SparkSession.onSessionClose(this) } - /** - * Add a tag to be assigned to all the operations started by this thread in this session. - * - * Often, a unit of execution in an application consists of multiple Spark executions. - * Application programmers can use this method to group all those jobs together and give a group - * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all - * running running executions with this tag. For example: - * {{{ - * // In the main thread: - * spark.addTag("myjobs") - * spark.range(10).map(i => { Thread.sleep(10); i }).collect() - * - * // In a separate thread: - * spark.interruptTag("myjobs") - * }}} - * - * There may be multiple tags present at the same time, so different parts of application may - * use different tags to perform cancellation at different levels of granularity. - * - * @param tag - * The tag to be added. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def addTag(tag: String): Unit = { - client.addTag(tag) - } + /** @inheritdoc */ + override def addTag(tag: String): Unit = client.addTag(tag) - /** - * Remove a tag previously added to be assigned to all the operations started by this thread in - * this session. Noop if such a tag was not added earlier. - * - * @param tag - * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. - * - * @since 3.5.0 - */ - def removeTag(tag: String): Unit = { - client.removeTag(tag) - } + /** @inheritdoc */ + override def removeTag(tag: String): Unit = client.removeTag(tag) - /** - * Get the tags that are currently set to be assigned to all the operations started by this - * thread. - * - * @since 3.5.0 - */ - def getTags(): Set[String] = { - client.getTags() - } + /** @inheritdoc */ + override def getTags(): Set[String] = client.getTags() - /** - * Clear the current thread's operation tags. - * - * @since 3.5.0 - */ - def clearTags(): Unit = { - client.clearTags() - } + /** @inheritdoc */ + override def clearTags(): Unit = client.clearTags() /** * We cannot deserialize a connect [[SparkSession]] because of a class clash on the server side. 
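A minimal usage sketch (not part of the patch) of the tag-based interruption API that these Connect client overrides delegate to; it mirrors the scaladoc example being relocated to sql/api and assumes `spark` is a Connect `SparkSession`:

import spark.implicits._

// In the main thread: tag the operations started by this thread in this session.
spark.addTag("myjobs")
spark.range(10).map(i => { Thread.sleep(10); i }).collect()

// In a separate thread: interrupt everything tagged "myjobs"; the returned
// sequence contains the operation IDs that were asked to be interrupted.
val interrupted: Seq[String] = spark.interruptTag("myjobs")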
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index f4043f19eb6a..abf03cfbc672 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -365,21 +365,6 @@ object CheckConnectJvmClientCompatibility { // Experimental ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession.registerClassFinder"), - // public - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptAll"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.interruptOperation"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.addTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.removeTag"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.getTags"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.SparkSession.clearTags"), // SparkSession#Builder ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.SparkSession#Builder.remote"), diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 485f0abcd25e..042179d86c31 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -27,6 +27,7 @@ import scala.collection.Map import scala.collection.concurrent.{Map => ScalaConcurrentMap} import scala.collection.immutable import scala.collection.mutable.HashMap +import scala.concurrent.{Future, Promise} import scala.jdk.CollectionConverters._ import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal @@ -909,10 +910,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - def addJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def addJobTag(tag: String): Unit = addJobTags(Set(tag)) + + /** + * Add multiple tags to be assigned to all the jobs started by this thread. + * See [[addJobTag]] for more details. + * + * @param tags The tags to be added. Cannot contain ',' (comma) character. + * + * @since 4.0.0 + */ + def addJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags + tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags ++ tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) setLocalProperty(SparkContext.SPARK_JOB_TAGS, newTags) } @@ -924,10 +935,20 @@ class SparkContext(config: SparkConf) extends Logging { * * @since 3.5.0 */ - def removeJobTag(tag: String): Unit = { - SparkContext.throwIfInvalidTag(tag) + def removeJobTag(tag: String): Unit = removeJobTags(Set(tag)) + + /** + * Remove multiple tags to be assigned to all the jobs started by this thread. + * See [[removeJobTag]] for more details. + * + * @param tags The tags to be removed. Cannot contain ',' (comma) character. 
+ * + * @since 4.0.0 + */ + def removeJobTags(tags: Set[String]): Unit = { + tags.foreach(SparkContext.throwIfInvalidTag) val existingTags = getJobTags() - val newTags = (existingTags - tag).mkString(SparkContext.SPARK_JOB_TAGS_SEP) + val newTags = (existingTags -- tags).mkString(SparkContext.SPARK_JOB_TAGS_SEP) if (newTags.isEmpty) { clearJobTags() } else { @@ -2684,6 +2705,25 @@ class SparkContext(config: SparkConf) extends Logging { dagScheduler.cancelJobGroup(groupId, cancelFutureJobs = true, None) } + /** + * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation. + * @return A future with [[ActiveJob]]s, allowing extraction of information such as Job ID and + * tags. + */ + private[spark] def cancelJobsWithTagWithFuture( + tag: String, + reason: String): Future[Seq[ActiveJob]] = { + SparkContext.throwIfInvalidTag(tag) + assertNotStopped() + + val cancelledJobs = Promise[Seq[ActiveJob]]() + dagScheduler.cancelJobsWithTag(tag, Some(reason), Some(cancelledJobs)) + cancelledJobs.future + } + + /** * Cancel active jobs that have the specified tag. See `org.apache.spark.SparkContext.addJobTag`. * @@ -2695,7 +2735,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String, reason: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, Option(reason)) + dagScheduler.cancelJobsWithTag(tag, Option(reason), cancelledJobs = None) } /** @@ -2708,7 +2748,7 @@ class SparkContext(config: SparkConf) extends Logging { def cancelJobsWithTag(tag: String): Unit = { SparkContext.throwIfInvalidTag(tag) assertNotStopped() - dagScheduler.cancelJobsWithTag(tag, None) + dagScheduler.cancelJobsWithTag(tag, reason = None, cancelledJobs = None) } /** Cancel all jobs that have been scheduled or are running. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6c824e2fdeae..2c89fe7885d0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -27,6 +27,7 @@ import scala.annotation.tailrec import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer} +import scala.concurrent.Promise import scala.concurrent.duration._ import scala.util.control.NonFatal @@ -1116,11 +1117,18 @@ private[spark] class DAGScheduler( /** * Cancel all jobs with a given tag. + * + * @param tag The tag to be cancelled. Cannot contain ',' (comma) character. + * @param reason reason for cancellation. + * @param cancelledJobs a promise to be completed with the [[ActiveJob]]s that were cancelled.
*/ - def cancelJobsWithTag(tag: String, reason: Option[String]): Unit = { + def cancelJobsWithTag( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { SparkContext.throwIfInvalidTag(tag) logInfo(log"Asked to cancel jobs with tag ${MDC(TAG, tag)}") - eventProcessLoop.post(JobTagCancelled(tag, reason)) + eventProcessLoop.post(JobTagCancelled(tag, reason, cancelledJobs)) } /** @@ -1234,17 +1242,22 @@ private[spark] class DAGScheduler( jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) } - private[scheduler] def handleJobTagCancelled(tag: String, reason: Option[String]): Unit = { - // Cancel all jobs belonging that have this tag. + private[scheduler] def handleJobTagCancelled( + tag: String, + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]): Unit = { + // Cancel all active jobs that have this tag. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter { activeJob => + val jobsToBeCancelled = activeJobs.filter { activeJob => Option(activeJob.properties).exists { properties => Option(properties.getProperty(SparkContext.SPARK_JOB_TAGS)).getOrElse("") .split(SparkContext.SPARK_JOB_TAGS_SEP).filter(!_.isEmpty).toSet.contains(tag) } - }.map(_.jobId) - val updatedReason = reason.getOrElse("part of cancelled job tag %s".format(tag)) - jobIds.foreach(handleJobCancellation(_, Option(updatedReason))) + } + val updatedReason = + reason.getOrElse("part of cancelled job tags %s".format(tag)) + jobsToBeCancelled.map(_.jobId).foreach(handleJobCancellation(_, Option(updatedReason))) + cancelledJobs.foreach(_.success(jobsToBeCancelled.toSeq)) } private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo): Unit = { @@ -3113,8 +3126,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case JobGroupCancelled(groupId, cancelFutureJobs, reason) => dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason) - case JobTagCancelled(tag, reason) => - dagScheduler.handleJobTagCancelled(tag, reason) + case JobTagCancelled(tag, reason, cancelledJobs) => + dagScheduler.handleJobTagCancelled(tag, reason, cancelledJobs) case AllJobsCancelled => dagScheduler.doCancelAllJobs() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index c9ad54d1fdc7..8932d2ef323b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties +import scala.concurrent.Promise + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.{AccumulatorV2, CallSite} @@ -71,7 +73,8 @@ private[scheduler] case class JobGroupCancelled( private[scheduler] case class JobTagCancelled( tagName: String, - reason: Option[String]) extends DAGSchedulerEvent + reason: Option[String], + cancelledJobs: Option[Promise[Seq[ActiveJob]]]) extends DAGSchedulerEvent private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index cf502c746d24..4767a5e1dfd2 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
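Before the sql/api changes that follow, a hedged sketch of how the new core-level pieces above fit together. `cancelJobsWithTagWithFuture` and `ThreadUtils` are `private[spark]`, so this only compiles inside Spark-internal code; `sc` is an assumed `SparkContext`:

import scala.concurrent.duration._
import org.apache.spark.util.ThreadUtils

// Tag all jobs started by this thread, in bulk.
sc.addJobTags(Set("etl", "nightly"))
// ... submit jobs from this thread ...

// Ask the DAGScheduler to cancel every active job carrying the "nightly" tag and
// wait for the Promise (plumbed through JobTagCancelled) to report the affected
// ActiveJobs, e.g. to read their job IDs or properties.
val cancelledF = sc.cancelJobsWithTagWithFuture("nightly", "nightly run superseded")
val cancelledJobIds = ThreadUtils.awaitResult(cancelledF, 60.seconds).map(_.jobId)

// Stop tagging subsequent jobs from this thread.
sc.removeJobTags(Set("etl", "nightly"))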
@@ -390,6 +390,98 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C @scala.annotation.varargs def addArtifacts(uri: URI*): Unit + /** + * Add a tag to be assigned to all the operations started by this thread in this session. + * + * Often, a unit of execution in an application consists of multiple Spark executions. + * Application programmers can use this method to group all those jobs together and give a group + * tag. The application can use `org.apache.spark.sql.SparkSession.interruptTag` to cancel all + * running executions with this tag. For example: + * {{{ + * // In the main thread: + * spark.addTag("myjobs") + * spark.range(10).map(i => { Thread.sleep(10); i }).collect() + * + * // In a separate thread: + * spark.interruptTag("myjobs") + * }}} + * + * There may be multiple tags present at the same time, so different parts of application may + * use different tags to perform cancellation at different levels of granularity. + * + * @param tag + * The tag to be added. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def addTag(tag: String): Unit + + /** + * Remove a tag previously added to be assigned to all the operations started by this thread in + * this session. Noop if such a tag was not added earlier. + * + * @param tag + * The tag to be removed. Cannot contain ',' (comma) character or be an empty string. + * + * @since 4.0.0 + */ + def removeTag(tag: String): Unit + + /** + * Get the operation tags that are currently set to be assigned to all the operations started by + * this thread in this session. + * + * @since 4.0.0 + */ + def getTags(): Set[String] + + /** + * Clear the current thread's operation tags. + * + * @since 4.0.0 + */ + def clearTags(): Unit + + /** + * Request to interrupt all currently running operations of this session. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. + * + * @since 4.0.0 + */ + def interruptAll(): Seq[String] + + /** + * Request to interrupt all currently running operations of this session with the given job tag. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * Sequence of operation IDs requested to be interrupted. + * + * @since 4.0.0 + */ + def interruptTag(tag: String): Seq[String] + + /** + * Request to interrupt an operation of this session, given its operation ID. + * + * @note + * This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return + * The operation ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session. + * + * @since 4.0.0 + */ + def interruptOperation(operationId: String): Seq[String] + /** * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a * `DataFrame`. diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index a8b2044ba8a4..1ee86ae1a113 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -6853,8 +6853,8 @@ object functions { /** * Converts a column containing nested inputs (array/map/struct) into a variants where maps and - * structs are converted to variant objects which are unordered unlike SQL structs. 
Input maps can - * only have string keys. + * structs are converted to variant objects which are unordered unlike SQL structs. Input maps + * can only have string keys. * * @param col * a column with a nested schema or column name. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index a7fb71d95d14..bfad54637ed2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql import java.net.URI import java.nio.file.Paths import java.util.{ServiceLoader, UUID} +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -57,7 +59,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.{CallSite, SparkFileUtils, Utils} +import org.apache.spark.util.{CallSite, SparkFileUtils, ThreadUtils, Utils} import org.apache.spark.util.ArrayImplicits._ /** @@ -92,7 +94,8 @@ class SparkSession private( @transient private val existingSharedState: Option[SharedState], @transient private val parentSessionState: Option[SessionState], @transient private[sql] val extensions: SparkSessionExtensions, - @transient private[sql] val initialSessionOptions: Map[String, String]) + @transient private[sql] val initialSessionOptions: Map[String, String], + @transient private val parentManagedJobTags: Map[String, String]) extends api.SparkSession[Dataset] with Logging { self => // The call site where this SparkSession was constructed. @@ -107,7 +110,12 @@ class SparkSession private( private[sql] def this( sc: SparkContext, initialSessionOptions: java.util.HashMap[String, String]) = { - this(sc, None, None, applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap) + this( + sc, + existingSharedState = None, + parentSessionState = None, + applyAndLoadExtensions(sc), initialSessionOptions.asScala.toMap, + parentManagedJobTags = Map.empty) } private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) @@ -122,6 +130,18 @@ class SparkSession private( .getOrElse(SQLConf.getFallbackConf) }) + /** Tag to mark all jobs owned by this session. */ + private[sql] lazy val sessionJobTag = s"spark-session-$sessionUUID" + + /** + * A map to hold the mapping from user-defined tags to the real tags attached to Jobs. + * Real tag have the current session ID attached: `"tag1" -> s"spark-session-$sessionUUID-tag1"`. 
+ */ + @transient + private[sql] lazy val managedJobTags: ConcurrentHashMap[String, String] = { + new ConcurrentHashMap(parentManagedJobTags.asJava) + } + /** @inheritdoc */ def version: String = SPARK_VERSION @@ -243,7 +263,8 @@ class SparkSession private( Some(sharedState), parentSessionState = None, extensions, - initialSessionOptions) + initialSessionOptions, + parentManagedJobTags = Map.empty) } /** @@ -264,8 +285,10 @@ class SparkSession private( Some(sharedState), Some(sessionState), extensions, - Map.empty) + Map.empty, + managedJobTags.asScala.toMap) result.sessionState // force copy of SessionState + result.managedJobTags // force copy of managedJobTags result } @@ -644,6 +667,83 @@ class SparkSession private( artifactManager.addLocalArtifacts(uri.flatMap(Artifact.parseArtifacts)) } + /** @inheritdoc */ + override def addTag(tag: String): Unit = { + SparkContext.throwIfInvalidTag(tag) + managedJobTags.put(tag, s"spark-session-$sessionUUID-$tag") + } + + /** @inheritdoc */ + override def removeTag(tag: String): Unit = managedJobTags.remove(tag) + + /** @inheritdoc */ + override def getTags(): Set[String] = managedJobTags.keys().asScala.toSet + + /** @inheritdoc */ + override def clearTags(): Unit = managedJobTags.clear() + + /** + * Request to interrupt all currently running SQL operations of this session. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return Sequence of SQL execution IDs requested to be interrupted. + * + * @since 4.0.0 + */ + override def interruptAll(): Seq[String] = + doInterruptTag(sessionJobTag, "as part of cancellation of all jobs") + + /** + * Request to interrupt all currently running SQL operations of this session with the given + * job tag. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return Sequence of SQL execution IDs requested to be interrupted. + * + * @since 4.0.0 + */ + override def interruptTag(tag: String): Seq[String] = { + val realTag = managedJobTags.get(tag) + if (realTag == null) return Seq.empty + doInterruptTag(realTag, s"part of cancelled job tags $tag") + } + + private def doInterruptTag(tag: String, reason: String): Seq[String] = { + val cancelledJobs = + sparkContext.cancelJobsWithTagWithFuture(tag, reason) + + ThreadUtils.awaitResult(cancelledJobs, 60.seconds) + .flatMap(job => Option(job.properties.getProperty(SQLExecution.EXECUTION_ROOT_ID_KEY))) + } + + /** + * Request to interrupt a SQL operation of this session, given its SQL execution ID. + * + * @note Only DataFrame/SQL operations started by this session can be interrupted. + * + * @note This method will wait up to 60 seconds for the interruption request to be issued. + * + * @return The execution ID requested to be interrupted, as a single-element sequence, or an empty + * sequence if the operation is not started by this session.
+ * + * @since 4.0.0 + */ + override def interruptOperation(operationId: String): Seq[String] = { + scala.util.Try(operationId.toLong).toOption match { + case Some(executionIdToBeCancelled) => + val tagToBeCancelled = SQLExecution.executionIdJobTag(this, executionIdToBeCancelled) + doInterruptTag(tagToBeCancelled, reason = "") + case None => + throw new IllegalArgumentException("executionId must be a number in string form.") + } + } + /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(self) @@ -730,7 +830,7 @@ class SparkSession private( } /** - * Execute a block of code with the this session set as the active session, and restore the + * Execute a block of code with this session set as the active session, and restore the * previous session on completion. */ private[sql] def withActive[T](block: => T): T = { @@ -965,7 +1065,12 @@ object SparkSession extends Logging { loadExtensions(extensions) applyExtensions(sparkContext, extensions) - session = new SparkSession(sparkContext, None, None, extensions, options.toMap) + session = new SparkSession(sparkContext, + existingSharedState = None, + parentSessionState = None, + extensions, + initialSessionOptions = options.toMap, + parentManagedJobTags = Map.empty) setDefaultSession(session) setActiveSession(session) registerContextListener(sparkContext) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 58fff2d4a1a2..3a406f4c0d0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -44,7 +44,7 @@ object SQLExecution extends Logging { private def nextExecutionId: Long = _nextExecutionId.getAndIncrement - private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() + private[sql] val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() def getQueryExecution(executionId: Long): QueryExecution = { executionIdToQueryExecution.get(executionId) @@ -52,6 +52,9 @@ object SQLExecution extends Logging { private val testing = sys.props.contains(IS_TESTING.key) + private[sql] def executionIdJobTag(session: SparkSession, id: Long) = + s"${session.sessionJobTag}-execution-root-id-$id" + private[sql] def checkSQLExecutionId(sparkSession: SparkSession): Unit = { val sc = sparkSession.sparkContext // only throw an exception during tests. a missing execution ID should not fail a job. @@ -82,6 +85,7 @@ object SQLExecution extends Logging { // And for the root execution, rootExecutionId == executionId. 
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == null) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, executionId.toString) + sc.addJobTag(executionIdJobTag(sparkSession, executionId)) } val rootExecutionId = sc.getLocalProperty(EXECUTION_ROOT_ID_KEY).toLong executionIdToQueryExecution.put(executionId, queryExecution) @@ -116,92 +120,94 @@ object SQLExecution extends Logging { val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) withSQLConfPropagated(sparkSession) { - var ex: Option[Throwable] = None - var isExecutedPlanAvailable = false - val startTime = System.nanoTime() - val startEvent = SparkListenerSQLExecutionStart( - executionId = executionId, - rootExecutionId = Some(rootExecutionId), - description = desc, - details = callSite.longForm, - physicalPlanDescription = "", - sparkPlanInfo = SparkPlanInfo.EMPTY, - time = System.currentTimeMillis(), - modifiedConfigs = redactedConfigs, - jobTags = sc.getJobTags(), - jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) - ) - try { - body match { - case Left(e) => - sc.listenerBus.post(startEvent) + withSessionTagsApplied(sparkSession) { + var ex: Option[Throwable] = None + var isExecutedPlanAvailable = false + val startTime = System.nanoTime() + val startEvent = SparkListenerSQLExecutionStart( + executionId = executionId, + rootExecutionId = Some(rootExecutionId), + description = desc, + details = callSite.longForm, + physicalPlanDescription = "", + sparkPlanInfo = SparkPlanInfo.EMPTY, + time = System.currentTimeMillis(), + modifiedConfigs = redactedConfigs, + jobTags = sc.getJobTags(), + jobGroupId = Option(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) + try { + body match { + case Left(e) => + sc.listenerBus.post(startEvent) + throw e + case Right(f) => + val planDescriptionMode = + ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val planDesc = queryExecution.explainString(planDescriptionMode) + val planInfo = try { + SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) + } catch { + case NonFatal(e) => + logDebug("Failed to generate SparkPlanInfo", e) + // If the queryExecution already failed before this, we are not able to generate + // the the plan info, so we use and empty graphviz node to make the UI happy + SparkPlanInfo.EMPTY + } + sc.listenerBus.post( + startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) + isExecutedPlanAvailable = true + f() + } + } catch { + case e: Throwable => + ex = Some(e) throw e - case Right(f) => - val planDescriptionMode = - ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) - val planDesc = queryExecution.explainString(planDescriptionMode) - val planInfo = try { - SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan) - } catch { - case NonFatal(e) => - logDebug("Failed to generate SparkPlanInfo", e) - // If the queryExecution already failed before this, we are not able to generate - // the the plan info, so we use and empty graphviz node to make the UI happy - SparkPlanInfo.EMPTY - } - sc.listenerBus.post( - startEvent.copy(physicalPlanDescription = planDesc, sparkPlanInfo = planInfo)) - isExecutedPlanAvailable = true - f() - } - } catch { - case e: Throwable => - ex = Some(e) - throw e - } finally { - val endTime = System.nanoTime() - val errorMessage = ex.map { - case e: SparkThrowable => - SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) - case e => - Utils.exceptionString(e) - } - if (queryExecution.shuffleCleanupMode != DoNotCleanup - && 
isExecutedPlanAvailable) { - val shuffleIds = queryExecution.executedPlan match { - case ae: AdaptiveSparkPlanExec => - ae.context.shuffleIds.asScala.keys - case _ => - Iterable.empty + } finally { + val endTime = System.nanoTime() + val errorMessage = ex.map { + case e: SparkThrowable => + SparkThrowableHelper.getMessage(e, ErrorMessageFormat.PRETTY) + case e => + Utils.exceptionString(e) } - shuffleIds.foreach { shuffleId => - queryExecution.shuffleCleanupMode match { - case RemoveShuffleFiles => - // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister - // the shuffle on MapOutputTracker, so that stage retries would be triggered. - // Set blocking to Utils.isTesting to deflake unit tests. - sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) - case SkipMigration => - SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) - case _ => // this should not happen + if (queryExecution.shuffleCleanupMode != DoNotCleanup + && isExecutedPlanAvailable) { + val shuffleIds = queryExecution.executedPlan match { + case ae: AdaptiveSparkPlanExec => + ae.context.shuffleIds.asScala.keys + case _ => + Iterable.empty + } + shuffleIds.foreach { shuffleId => + queryExecution.shuffleCleanupMode match { + case RemoveShuffleFiles => + // Same as what we do in ContextCleaner.doCleanupShuffle, but do not unregister + // the shuffle on MapOutputTracker, so that stage retries would be triggered. + // Set blocking to Utils.isTesting to deflake unit tests. + sc.shuffleDriverComponents.removeShuffle(shuffleId, Utils.isTesting) + case SkipMigration => + SparkEnv.get.blockManager.migratableResolver.addShuffleToSkip(shuffleId) + case _ => // this should not happen + } } } + val event = SparkListenerSQLExecutionEnd( + executionId, + System.currentTimeMillis(), + // Use empty string to indicate no error, as None may mean events generated by old + // versions of Spark. + errorMessage.orElse(Some(""))) + // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the + // `name` parameter. The `ExecutionListenerManager` only watches SQL executions with + // name. We can specify the execution name in more places in the future, so that + // `QueryExecutionListener` can track more cases. + event.executionName = name + event.duration = endTime - startTime + event.qe = queryExecution + event.executionFailure = ex + sc.listenerBus.post(event) } - val event = SparkListenerSQLExecutionEnd( - executionId, - System.currentTimeMillis(), - // Use empty string to indicate no error, as None may mean events generated by old - // versions of Spark. - errorMessage.orElse(Some(""))) - // Currently only `Dataset.withAction` and `DataFrameWriter.runCommand` specify the `name` - // parameter. The `ExecutionListenerManager` only watches SQL executions with name. We - // can specify the execution name in more places in the future, so that - // `QueryExecutionListener` can track more cases. - event.executionName = name - event.duration = endTime - startTime - event.qe = queryExecution - event.executionFailure = ex - sc.listenerBus.post(event) } } } finally { @@ -211,6 +217,7 @@ object SQLExecution extends Logging { // The current execution is the root execution if rootExecutionId == executionId. 
if (sc.getLocalProperty(EXECUTION_ROOT_ID_KEY) == executionId.toString) { sc.setLocalProperty(EXECUTION_ROOT_ID_KEY, null) + sc.removeJobTag(executionIdJobTag(sparkSession, executionId)) } sc.setLocalProperty(SPARK_JOB_INTERRUPT_ON_CANCEL, originalInterruptOnCancel) } @@ -238,15 +245,28 @@ object SQLExecution extends Logging { val sc = sparkSession.sparkContext val oldExecutionId = sc.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) withSQLConfPropagated(sparkSession) { - try { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) - body - } finally { - sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + withSessionTagsApplied(sparkSession) { + try { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, executionId) + body + } finally { + sc.setLocalProperty(SQLExecution.EXECUTION_ID_KEY, oldExecutionId) + } } } } + private[sql] def withSessionTagsApplied[T](sparkSession: SparkSession)(block: => T): T = { + val allTags = sparkSession.managedJobTags.values().asScala.toSet + sparkSession.sessionJobTag + sparkSession.sparkContext.addJobTags(allTags) + + try { + block + } finally { + sparkSession.sparkContext.removeJobTags(allTags) + } + } + /** * Wrap an action with specified SQL configs. These configs will be propagated to the executor * side via job local properties. @@ -286,10 +306,13 @@ object SQLExecution extends Logging { val originalSession = SparkSession.getActiveSession val originalLocalProps = sc.getLocalProperties SparkSession.setActiveSession(activeSession) - sc.setLocalProperties(localProps) - val res = body - // reset active session and local props. - sc.setLocalProperties(originalLocalProps) + val res = withSessionTagsApplied(activeSession) { + sc.setLocalProperties(localProps) + val res = body + // reset active session and local props. + sc.setLocalProperties(originalLocalProps) + res + } if (originalSession.nonEmpty) { SparkSession.setActiveSession(originalSession.get) } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala new file mode 100644 index 000000000000..e9fd07ecf18b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionJobTaggingAndCancellationSuite.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.util.concurrent.{ConcurrentHashMap, Semaphore, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger + +import scala.concurrent.{ExecutionContext, Future} +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd, SparkListenerJobStart} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.tags.ExtendedSQLTest +import org.apache.spark.util.ThreadUtils + +/** + * Test cases for the tagging and cancellation APIs provided by [[SparkSession]]. + */ +@ExtendedSQLTest +class SparkSessionJobTaggingAndCancellationSuite + extends SparkFunSuite + with Eventually + with LocalSparkContext { + + override def afterEach(): Unit = { + try { + // This suite should not interfere with the other test suites. + SparkSession.getActiveSession.foreach(_.stop()) + SparkSession.clearActiveSession() + SparkSession.getDefaultSession.foreach(_.stop()) + SparkSession.clearDefaultSession() + resetSparkContext() + } finally { + super.afterEach() + } + } + + test("Tags are not inherited by new sessions") { + val session = SparkSession.builder().master("local").getOrCreate() + + assert(session.getTags() == Set()) + session.addTag("one") + assert(session.getTags() == Set("one")) + + val newSession = session.newSession() + assert(newSession.getTags() == Set()) + } + + test("Tags are inherited by cloned sessions") { + val session = SparkSession.builder().master("local").getOrCreate() + + assert(session.getTags() == Set()) + session.addTag("one") + assert(session.getTags() == Set("one")) + + val clonedSession = session.cloneSession() + assert(clonedSession.getTags() == Set("one")) + clonedSession.addTag("two") + assert(clonedSession.getTags() == Set("one", "two")) + + // Tags are not propagated back to the original session + assert(session.getTags() == Set("one")) + } + + test("Tags set from session are prefixed with session UUID") { + sc = new SparkContext("local[2]", "test") + val session = SparkSession.builder().sparkContext(sc).getOrCreate() + import session.implicits._ + + val sem = new Semaphore(0) + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + sem.release() + } + }) + + session.addTag("one") + Future { + session.range(1, 10000).map { i => Thread.sleep(100); i }.count() + }(ExecutionContext.global) + + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + val activeJobsFuture = + session.sparkContext.cancelJobsWithTagWithFuture(session.managedJobTags.get("one"), "reason") + val activeJob = ThreadUtils.awaitResult(activeJobsFuture, 60.seconds).head + val actualTags = activeJob.properties.getProperty(SparkContext.SPARK_JOB_TAGS) + .split(SparkContext.SPARK_JOB_TAGS_SEP) + assert(actualTags.toSet == Set( + session.sessionJobTag, + s"${session.sessionJobTag}-one", + SQLExecution.executionIdJobTag( + session, + activeJob.properties.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong))) + } + + test("Cancellation APIs in SparkSession are isolated") { + sc = new SparkContext("local[2]", "test") + val globalSession = SparkSession.builder().sparkContext(sc).getOrCreate() + var (sessionA, sessionB, sessionC): (SparkSession, SparkSession, SparkSession) = + (null, null, null) + + // global ExecutionContext has only 2 threads in Apache Spark CI + // 
create own thread pool for the three Futures used in this test + val numThreads = 3 + val fpool = ThreadUtils.newForkJoinPool("job-tags-test-thread-pool", numThreads) + val executionContext = ExecutionContext.fromExecutorService(fpool) + + try { + // Add a listener to release the semaphore once jobs are launched. + val sem = new Semaphore(0) + val jobEnded = new AtomicInteger(0) + val jobProperties: ConcurrentHashMap[Int, java.util.Properties] = new ConcurrentHashMap() + + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobProperties.put(jobStart.jobId, jobStart.properties) + sem.release() + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + sem.release() + jobEnded.incrementAndGet() + } + }) + + // Note: since tags are added in the Future threads, they don't need to be cleared in between. + val jobA = Future { + sessionA = globalSession.cloneSession() + import globalSession.implicits._ + + assert(sessionA.getTags() == Set()) + sessionA.addTag("two") + assert(sessionA.getTags() == Set("two")) + sessionA.clearTags() // check that clearing all tags works + assert(sessionA.getTags() == Set()) + sessionA.addTag("one") + assert(sessionA.getTags() == Set("one")) + try { + sessionA.range(1, 10000).map { i => Thread.sleep(100); i }.count() + } finally { + sessionA.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobB = Future { + sessionB = globalSession.cloneSession() + import globalSession.implicits._ + + assert(sessionB.getTags() == Set()) + sessionB.addTag("one") + sessionB.addTag("two") + sessionB.addTag("one") + sessionB.addTag("two") // duplicates shouldn't matter + assert(sessionB.getTags() == Set("one", "two")) + try { + sessionB.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sessionB.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + val jobC = Future { + sessionC = globalSession.cloneSession() + import globalSession.implicits._ + + sessionC.addTag("foo") + sessionC.removeTag("foo") + assert(sessionC.getTags() == Set()) // check that remove works removing the last tag + sessionC.addTag("boo") + try { + sessionC.range(1, 10000, 2).map { i => Thread.sleep(100); i }.count() + } finally { + sessionC.clearTags() // clear for the case of thread reuse by another Future + } + }(executionContext) + + // Block until all three jobs have started.
+ assert(sem.tryAcquire(3, 1, TimeUnit.MINUTES)) + + // Tags are applied + assert(jobProperties.size == 3) + for (ss <- Seq(sessionA, sessionB, sessionC)) { + val jobProperty = jobProperties.values().asScala.filter(_.get(SparkContext.SPARK_JOB_TAGS) + .asInstanceOf[String].contains(ss.sessionUUID)) + assert(jobProperty.size == 1) + val tags = jobProperty.head.get(SparkContext.SPARK_JOB_TAGS).asInstanceOf[String] + .split(SparkContext.SPARK_JOB_TAGS_SEP) + + val executionRootIdTag = SQLExecution.executionIdJobTag( + ss, + jobProperty.head.get(SQLExecution.EXECUTION_ROOT_ID_KEY).asInstanceOf[String].toLong) + val userTagsPrefix = s"spark-session-${ss.sessionUUID}-" + + ss match { + case s if s == sessionA => assert(tags.toSet == Set( + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one")) + case s if s == sessionB => assert(tags.toSet == Set( + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}one", s"${userTagsPrefix}two")) + case s if s == sessionC => assert(tags.toSet == Set( + s.sessionJobTag, executionRootIdTag, s"${userTagsPrefix}boo")) + } + } + + // Global session cancels nothing + assert(globalSession.interruptAll().isEmpty) + assert(globalSession.interruptTag("one").isEmpty) + assert(globalSession.interruptTag("two").isEmpty) + for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) { + assert(globalSession.interruptOperation(i.toString).isEmpty) + } + assert(jobEnded.intValue == 0) + + // One job cancelled + for (i <- SQLExecution.executionIdToQueryExecution.keys().asScala) { + sessionC.interruptOperation(i.toString) + } + val eC = intercept[SparkException] { + ThreadUtils.awaitResult(jobC, 1.minute) + }.getCause + assert(eC.getMessage contains "cancelled") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 1) + + // Another job cancelled + assert(sessionA.interruptTag("one").size == 1) + val eA = intercept[SparkException] { + ThreadUtils.awaitResult(jobA, 1.minute) + }.getCause + assert(eA.getMessage contains "cancelled job tags one") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 2) + + // The last job cancelled + sessionB.interruptAll() + val eB = intercept[SparkException] { + ThreadUtils.awaitResult(jobB, 1.minute) + }.getCause + assert(eB.getMessage contains "cancellation of all jobs") + assert(sem.tryAcquire(1, 1, TimeUnit.MINUTES)) + assert(jobEnded.intValue == 3) + } finally { + fpool.shutdownNow() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 94d33731b6de..059a4c9b8376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -228,7 +228,7 @@ class SQLExecutionSuite extends SparkFunSuite with SQLConfHelper { spark.range(1).collect() spark.sparkContext.listenerBus.waitUntilEmpty() - assert(jobTags.contains(jobTag)) + assert(jobTags.get.contains(jobTag)) assert(sqlJobTags.contains(jobTag)) } finally { spark.sparkContext.removeJobTag(jobTag)
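Taken together, a hedged end-to-end sketch of the session-level behaviour the new suite exercises; tag values follow the formats introduced above (`spark-session-$sessionUUID` plus a `-$tag` or `-execution-root-id-$id` suffix), and the session setup is illustrative only:

import scala.concurrent.{ExecutionContext, Future}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
import spark.implicits._

// The user-facing tag is kept in managedJobTags as
// "myjobs" -> s"spark-session-${spark.sessionUUID}-myjobs" and is applied to the
// SparkContext job tags around each SQL execution, together with sessionJobTag
// and a per-root-execution tag from SQLExecution.executionIdJobTag.
spark.addTag("myjobs")

Future {
  spark.range(1, 10000).map { i => Thread.sleep(10); i }.count()
}(ExecutionContext.global)

// From another thread: interrupt by the user-facing tag. The returned strings are
// the root SQL execution IDs of the interrupted queries, the same IDs that
// interruptOperation accepts and interruptAll reports.
val interruptedExecutionIds: Seq[String] = spark.interruptTag("myjobs")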