
Commit 8f4fc22

fsamuel-bs authored and Mridul Muralidharan committed
[SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events
### What changes were proposed in this pull request?

Proposing a new set of APIs for ExecutorPlugins, to provide callbacks invoked at the start and end of each task of a job. Not very opinionated on the shape of the API; tried to be as minimal as possible for now.

### Why are the changes needed?

The changes are described in detail on [SPARK-33088](https://issues.apache.org/jira/browse/SPARK-33088), but mostly this boils down to:

1. This feature was considered when the ExecutorPlugin API was initially introduced in #21923, but never implemented.
2. The use case which **requires** this feature is to propagate tracing information from the driver to the executor, such that calls from the same job can all be traced.
   a. Tracing frameworks are usually set up in thread locals, so it is important for the setup to happen in the same thread which runs the tasks.
   b. Executors can be shared by multiple jobs, so it is not sufficient to set tracing information at executor startup time -- it needs to happen every time a task starts or ends.

### Does this PR introduce _any_ user-facing change?

No. This PR introduces new APIs for future plugin developers to use.

### How was this patch tested?

Unit tests in `PluginContainerSuite`.

Closes #29977 from fsamuel-bs/SPARK-33088.

Authored-by: Samuel Souza <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
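To illustrate the tracing use case described in the commit message, here is a minimal sketch of an executor-side plugin. It is not part of this commit; the `TracingExecutorPlugin` class and the string-based trace value are hypothetical, and a real integration would plug in an actual tracing library.

```scala
import java.util.{Map => JMap}

import org.apache.spark.{TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{ExecutorPlugin, PluginContext}

// Hypothetical example: propagate per-task tracing state via a thread local.
class TracingExecutorPlugin extends ExecutorPlugin {

  // Tracing frameworks typically keep their context in a thread local; onTaskStart
  // runs on the task's own thread, so state set here is visible to the task body.
  private val currentTrace = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }

  override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
    // One-time, per-executor setup (e.g. configuring a tracing client) goes here.
  }

  override def onTaskStart(): Unit = {
    // Task-specific information is available through TaskContext.get().
    val trace = Option(TaskContext.get())
      .map(tc => s"stage=${tc.stageId()} taskAttempt=${tc.taskAttemptId()}")
    currentTrace.set(trace)
  }

  override def onTaskSucceeded(): Unit = currentTrace.remove()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = currentTrace.remove()
}
```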
1 parent a5c17de commit 8f4fc22

File tree: 6 files changed, +163 −17 lines

core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java

Lines changed: 42 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 
 import java.util.Map;
 
+import org.apache.spark.TaskFailedReason;
 import org.apache.spark.annotation.DeveloperApi;
 
 /**
@@ -54,4 +55,45 @@ default void init(PluginContext ctx, Map<String, String> extraConf) {}
    */
   default void shutdown() {}
 
+  /**
+   * Perform any action before the task is run.
+   * <p>
+   * This method is invoked from the same thread the task will be executed.
+   * Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
+   * <p>
+   * Plugin authors should avoid expensive operations here, as this method will be called
+   * on every task, and doing something expensive can significantly slow down a job.
+   * It is not recommended for a user to call a remote service, for example.
+   * <p>
+   * Exceptions thrown from this method do not propagate - they're caught,
+   * logged, and suppressed. Therefore exceptions when executing this method won't
+   * make the job fail.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskStart() {}
+
+  /**
+   * Perform an action after tasks completes without exceptions.
+   * <p>
+   * As {@link #onTaskStart() onTaskStart} exceptions are suppressed, this method
+   * will still be invoked even if the corresponding {@link #onTaskStart} call for this
+   * task failed.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskSucceeded() {}
+
+  /**
+   * Perform an action after tasks completes with exceptions.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @param failureReason the exception thrown from the failed task.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskFailed(TaskFailedReason failureReason) {}
 }
```
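As a hedged illustration of how the three callbacks above pair up (not part of this commit), the sketch below records per-task wall-clock time. The cleanup path is shared by `onTaskSucceeded` and `onTaskFailed`, and the per-callback work is kept cheap, as the javadoc advises. The class name and the aggregation choice are assumptions.

```scala
import java.util.concurrent.atomic.LongAdder

import org.apache.spark.TaskFailedReason
import org.apache.spark.api.plugin.ExecutorPlugin

// Hypothetical example: accumulate time spent in tasks on this executor.
class TaskTimingPlugin extends ExecutorPlugin {

  // Per-thread start timestamp; each task runs on its own thread.
  private val startNs = new ThreadLocal[Long] {
    override def initialValue(): Long = 0L
  }
  // Total task time across all task threads of this executor.
  private val totalTaskNanos = new LongAdder

  override def onTaskStart(): Unit = startNs.set(System.nanoTime())

  override def onTaskSucceeded(): Unit = record()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = record()

  private def record(): Unit = {
    // Guard against onTaskStart having failed, since the end callbacks still fire.
    val start = startNs.get()
    if (start != 0L) totalTaskNanos.add(System.nanoTime() - start)
    startNs.set(0L)
  }
}
```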

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 19 additions & 13 deletions
```diff
@@ -253,7 +253,7 @@ private[spark] class Executor(
   }
 
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, plugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
     if (decommissioned) {
@@ -332,7 +332,8 @@
 
   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val plugins: Option[PluginContainer])
     extends Runnable {
 
     val taskId = taskDescription.taskId
@@ -479,7 +480,8 @@
           taskAttemptId = taskId,
           attemptNumber = taskDescription.attemptNumber,
           metricsSystem = env.metricsSystem,
-          resources = taskDescription.resources)
+          resources = taskDescription.resources,
+          plugins = plugins)
         threwException = false
         res
       } {
@@ -614,6 +616,7 @@
 
         executorSource.SUCCEEDED_TASKS.inc(1L)
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskSucceeded())
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
       } catch {
         case t: TaskKilledException =>
@@ -623,9 +626,9 @@
           // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
           // without requiring a copy.
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(
-            TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks.toSeq)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case _: InterruptedException | NonFatal(_) if
             task != null && task.reasonIfKilled.isDefined =>
@@ -634,9 +637,9 @@
 
           val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(
-            TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks.toSeq)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -650,11 +653,13 @@
               s"other exception: $t")
           }
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if env.isStopped =>
@@ -677,21 +682,22 @@
             val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
             val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
 
-            val serializedTaskEndReason = {
+            val (taskFailureReason, serializedTaskFailureReason) = {
               try {
                 val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
                   .withMetricPeaks(metricPeaks.toSeq)
-                ser.serialize(ef)
+                (ef, ser.serialize(ef))
               } catch {
                 case _: NotSerializableException =>
                   // t is not serializable so just send the stacktrace
                   val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
                     .withMetricPeaks(metricPeaks.toSeq)
-                  ser.serialize(ef)
+                  (ef, ser.serialize(ef))
               }
             }
             setTaskFinishedAndClearInterruptStatus()
-            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+            plugins.foreach(_.onTaskFailed(taskFailureReason))
+            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
           } else {
             logInfo("Not reporting error to driver during JVM shutdown.")
           }
```
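The executor code above hands concrete `TaskFailedReason` subtypes (`TaskKilled`, `ExceptionFailure`, a commit-denied reason, and so on) to `onTaskFailed`. Below is a hedged sketch of a plugin that distinguishes them; the class name and the logging destination are assumptions, not part of this commit.

```scala
import org.apache.spark.{ExceptionFailure, TaskFailedReason, TaskKilled}
import org.apache.spark.api.plugin.ExecutorPlugin

// Hypothetical example: react differently to different failure reasons.
class FailureReportingPlugin extends ExecutorPlugin {
  override def onTaskFailed(failureReason: TaskFailedReason): Unit = failureReason match {
    case ef: ExceptionFailure =>
      // The task body threw; className/description summarize the exception.
      System.err.println(s"Task threw ${ef.className}: ${ef.description}")
    case tk: TaskKilled =>
      System.err.println(s"Task killed: ${tk.reason}")
    case other =>
      System.err.println(s"Task failed: ${other.toErrorString}")
  }
}
```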

core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala

Lines changed: 48 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
 import scala.collection.JavaConverters._
 import scala.util.{Either, Left, Right}
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {
 
   def shutdown(): Unit
   def registerMetrics(appId: String): Unit
+  def onTaskStart(): Unit
+  def onTaskSucceeded(): Unit
+  def onTaskFailed(failureReason: TaskFailedReason): Unit
 
 }
 
@@ -85,6 +88,17 @@ private class DriverPluginContainer(
     }
   }
 
+  override def onTaskStart(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
 }
 
 private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
       }
     }
   }
+
+  override def onTaskStart(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskStart()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskSucceeded()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskFailed(failureReason)
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
+      }
+    }
+  }
 }
 
 object PluginContainer {
```
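The container above wraps every callback in its own try/catch, so a misbehaving plugin cannot fail a task. As a hedged illustration (not part of this commit), a plugin like the following only produces log noise on the executor; jobs still complete normally.

```scala
import org.apache.spark.api.plugin.ExecutorPlugin

// Hypothetical example: an always-throwing callback is caught and logged by
// ExecutorPluginContainer rather than propagated into the running task.
class NoisyExecutorPlugin extends ExecutorPlugin {
  override def onTaskStart(): Unit =
    throw new IllegalStateException("boom from onTaskStart")
}
```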

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -23,6 +23,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
-      resources: Map[String, ResourceInformation]): T = {
+      resources: Map[String, ResourceInformation],
+      plugins: Option[PluginContainer]): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -123,6 +125,8 @@
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()
 
+    plugins.foreach(_.onTaskStart())
+
     try {
       runTask(context)
     } catch {
```

core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala

Lines changed: 47 additions & 0 deletions
```diff
@@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     assert(TestSparkPlugin.driverPlugin != null)
   }
 
+  test("SPARK-33088: executor tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    sc.parallelize(1 to 10, 2).count()
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
+  }
+
+  test("SPARK-33088: executor failed tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    try {
+      sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
+    } catch {
+      case t: Throwable => // ignore exception
+    }
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
+  }
+
   test("plugin initialization in non-local mode") {
     val path = Utils.createTempDir()
 
@@ -309,13 +341,28 @@ private class TestDriverPlugin extends DriverPlugin {
 
 private class TestExecutorPlugin extends ExecutorPlugin {
 
+  var numOnTaskStart: Int = 0
+  var numOnTaskSucceeded: Int = 0
+  var numOnTaskFailed: Int = 0
+
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
       override def getValue(): Int = 84
     })
     TestSparkPlugin.executorContext = ctx
   }
 
+  override def onTaskStart(): Unit = {
+    numOnTaskStart += 1
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    numOnTaskSucceeded += 1
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    numOnTaskFailed += 1
+  }
 }
 
 private object TestSparkPlugin {
```
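Mirroring the test setup above, a plugin is enabled through the existing `spark.plugins` mechanism: the registered entry point is a `org.apache.spark.api.plugin.SparkPlugin` whose `executorPlugin()` returns the executor-side implementation. The sketch below is hypothetical (the class names are assumptions) and reuses the `TracingExecutorPlugin` sketched earlier.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}

// Hypothetical wrapper exposing the executor-side plugin to Spark.
class TracingSparkPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null          // no driver-side hooks needed
  override def executorPlugin(): ExecutorPlugin = new TracingExecutorPlugin
}

object TracingDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("plugin-callback-demo")
      .setMaster("local[2]")
      .set("spark.plugins", classOf[TracingSparkPlugin].getName)

    val sc = new SparkContext(conf)
    // Each of the 4 tasks triggers onTaskStart followed by onTaskSucceeded.
    sc.parallelize(1 to 100, 4).count()
    sc.stop()
  }
}
```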

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.completed)
   }
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.lastError.getMessage == "damn error")
   }
```