[SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events #29977
@@ -253,7 +253,7 @@ private[spark] class Executor(
   }

   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, plugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
     if (decommissioned) {

@@ -332,7 +332,8 @@ private[spark] class Executor(
   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val plugins: Option[PluginContainer])
Contributor
Maybe we can make the `plugins` parameter default to `None`?
Suggested change

Contributor
Same for

Contributor
@rshkv, what is the reason to make this default to None? This is an internal API and it is only called from here. It's an Option already, so people can check it easily. In some ways it's nice to force it, so you make sure all uses of it have been updated.
     extends Runnable {

     val taskId = taskDescription.taskId
@@ -479,7 +480,8 @@ private[spark] class Executor(
           taskAttemptId = taskId,
           attemptNumber = taskDescription.attemptNumber,
           metricsSystem = env.metricsSystem,
-          resources = taskDescription.resources)
+          resources = taskDescription.resources,
+          plugins = plugins)
         threwException = false
         res
       } {

@@ -614,6 +616,7 @@ private[spark] class Executor(
         executorSource.SUCCEEDED_TASKS.inc(1L)
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskSucceeded())
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
       } catch {
         case t: TaskKilledException =>

@@ -623,9 +626,9 @@ private[spark] class Executor(
         // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
         // without requiring a copy.
         val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-        val serializedTK = ser.serialize(
-          TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq))
-        execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+        val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
+        plugins.foreach(_.onTaskFailed(reason))
+        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

       case _: InterruptedException | NonFatal(_) if
           task != null && task.reasonIfKilled.isDefined =>

@@ -634,9 +637,9 @@ private[spark] class Executor(
         val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
         val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-        val serializedTK = ser.serialize(
-          TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq))
-        execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+        val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
+        plugins.foreach(_.onTaskFailed(reason))
+        execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

       case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
         val reason = task.context.fetchFailed.get.toTaskFailedReason

@@ -650,11 +653,13 @@ private[spark] class Executor(
             s"other exception: $t")
         }
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskFailed(reason))
         execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

       case CausedBy(cDE: CommitDeniedException) =>
         val reason = cDE.toTaskCommitDeniedReason
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskFailed(reason))
         execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))

       case t: Throwable if env.isStopped =>

@@ -677,21 +682,22 @@ private[spark] class Executor(
           val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))

-          val serializedTaskEndReason = {
+          val (taskFailureReason, serializedTaskFailureReason) = {
             try {
               val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
                 .withMetricPeaks(metricPeaks.toSeq)
-              ser.serialize(ef)
+              (ef, ser.serialize(ef))
             } catch {
               case _: NotSerializableException =>
                 // t is not serializable so just send the stacktrace
                 val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
                   .withMetricPeaks(metricPeaks.toSeq)
-                ser.serialize(ef)
+                (ef, ser.serialize(ef))
             }
           }
           setTaskFinishedAndClearInterruptStatus()
-          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+          plugins.foreach(_.onTaskFailed(taskFailureReason))
+          execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
         } else {
           logInfo("Not reporting error to driver during JVM shutdown.")
         }
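With the changes above, `TaskRunner` now carries the `PluginContainer` and fires exactly one of `onTaskSucceeded` or `onTaskFailed` per task (the matching `onTaskStart` call is added in `Task.run`, shown in the next file). As a rough sketch of what a consumer of the new API could look like — the class name, metric name, and gauge logic are invented for illustration; the callback signatures follow the `TestExecutorPlugin` added in the test changes below:

```scala
import java.util.{Map => JMap}
import java.util.concurrent.atomic.AtomicLong

import com.codahale.metrics.Gauge

import org.apache.spark.TaskFailedReason
import org.apache.spark.api.plugin.{ExecutorPlugin, PluginContext}

// Illustrative only: exposes the number of tasks currently running on this executor
// as a metric, driven purely by the new onTaskStart/onTaskSucceeded/onTaskFailed hooks.
class ActiveTaskGaugePlugin extends ExecutorPlugin {

  private val activeTasks = new AtomicLong(0L)

  override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
    // Register a gauge backed by the counter that the task callbacks below maintain.
    ctx.metricRegistry().register("activeTasks", new Gauge[Long] {
      override def getValue(): Long = activeTasks.get()
    })
  }

  override def onTaskStart(): Unit = activeTasks.incrementAndGet()

  override def onTaskSucceeded(): Unit = activeTasks.decrementAndGet()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = activeTasks.decrementAndGet()
}
```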
@@ -23,6 +23,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.InputFileBlockHolder

@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
-      resources: Map[String, ResourceInformation]): T = {
+      resources: Map[String, ResourceInformation],
+      plugins: Option[PluginContainer]): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.

@@ -123,6 +125,8 @@ private[spark] abstract class Task[T](
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()

+    plugins.foreach(_.onTaskStart())
Contributor
What is the expectation in case the plugin's `onTaskStart` throws an exception?

Contributor
Author
That's what I documented on https://github.com/apache/spark/pull/29977/files#diff-6a99ec9983962323b4e0c1899134b5d6R76-R78 -- the argument that came to mind is that it's easy for a plugin dev to track some state in a thread-local and cleanly decide whether it wants to perform the succeeded/failed action or not. Happy to change it if we prefer not to put this burden on the plugin owner though.

Contributor
Maybe I'm misunderstanding, but the documentation states "Exceptions thrown from this method do not propagate", and there is nothing here preventing that. I think perhaps you meant to say the user needs to make sure they don't propagate?

Contributor
Author
We catch Throwable on the plugin calls.

Contributor
Perhaps reword it to say exceptions are ignored?
     try {
       runTask(context)
     } catch {
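Following up on the thread-local idea in the discussion above: because `onTaskStart` fires from `Task.run` and the end callbacks fire from `TaskRunner` on the same task thread, a plugin can keep per-task state in a `ThreadLocal`. A minimal sketch, with hypothetical class and method names:

```scala
import org.apache.spark.TaskFailedReason
import org.apache.spark.api.plugin.ExecutorPlugin

// Illustrative only: remembers whether onTaskStart finished its work for the current
// task's thread, so the end callbacks only clean up when there is something to clean up.
class GuardedLifecyclePlugin extends ExecutorPlugin {

  private val startCompleted = new ThreadLocal[Boolean] {
    override def initialValue(): Boolean = false
  }

  override def onTaskStart(): Unit = {
    // ... acquire per-task resources here; if this throws, the flag stays false ...
    startCompleted.set(true)
  }

  override def onTaskSucceeded(): Unit = cleanUp()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = cleanUp()

  private def cleanUp(): Unit = {
    if (startCompleted.get()) {
      // ... release per-task resources only when onTaskStart got far enough ...
      startCompleted.set(false)
    }
  }
}
```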
@@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     assert(TestSparkPlugin.driverPlugin != null)
   }
| test("SPARK-33088: executor tasks trigger plugin calls") { | ||
| val conf = new SparkConf() | ||
| .setAppName(getClass().getName()) | ||
| .set(SparkLauncher.SPARK_MASTER, "local[1]") | ||
| .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) | ||
|
|
||
| sc = new SparkContext(conf) | ||
| sc.parallelize(1 to 10, 2).count() | ||
|
|
||
| assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2) | ||
| assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2) | ||
| assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0) | ||
| } | ||
|
|
||
| test("SPARK-33088: executor failed tasks trigger plugin calls") { | ||
| val conf = new SparkConf() | ||
| .setAppName(getClass().getName()) | ||
| .set(SparkLauncher.SPARK_MASTER, "local[1]") | ||
| .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) | ||
|
|
||
| sc = new SparkContext(conf) | ||
| try { | ||
| sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException) | ||
| } catch { | ||
| case t: Throwable => // ignore exception | ||
| } | ||
|
|
||
| assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2) | ||
| assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0) | ||
| assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2) | ||
|
Member
Hi, folks.

Contributor
Thanks @dongjoon-hyun
+  }

   test("plugin initialization in non-local mode") {
     val path = Utils.createTempDir()

@@ -309,13 +341,28 @@ private class TestDriverPlugin extends DriverPlugin {
 private class TestExecutorPlugin extends ExecutorPlugin {

+  var numOnTaskStart: Int = 0
+  var numOnTaskSucceeded: Int = 0
+  var numOnTaskFailed: Int = 0
+
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
       override def getValue(): Int = 84
     })
     TestSparkPlugin.executorContext = ctx
   }

+  override def onTaskStart(): Unit = {
+    numOnTaskStart += 1
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    numOnTaskSucceeded += 1
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    numOnTaskFailed += 1
+  }
 }

 private object TestSparkPlugin {
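The suite enables the test plugin through the `PLUGINS` config constant. For completeness, a sketch of the equivalent wiring in a user application via the `spark.plugins` setting — the `TaskLoggingPlugin` class and its println bodies are invented for this example; only the `SparkPlugin`/`ExecutorPlugin` interfaces and the config key come from Spark:

```scala
import org.apache.spark.{SparkConf, SparkContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}

// Illustrative only: a SparkPlugin that ships just an executor-side component
// which logs the new task lifecycle callbacks.
class TaskLoggingPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null  // no driver-side component needed here

  override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {
    override def onTaskStart(): Unit = println("task started")
    override def onTaskSucceeded(): Unit = println("task succeeded")
    override def onTaskFailed(reason: TaskFailedReason): Unit = println(s"task failed: $reason")
  }
}

object TaskLoggingPluginDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("task-plugin-demo")
      .setMaster("local[1]")
      .set("spark.plugins", classOf[TaskLoggingPlugin].getName)
    val sc = new SparkContext(conf)
    try {
      // Two partitions -> two tasks, so onTaskStart and onTaskSucceeded should each
      // fire twice, matching the counts asserted in the suite above.
      sc.parallelize(1 to 10, 2).count()
    } finally {
      sc.stop()
    }
  }
}
```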