
Commit d0048d7

fsamuel-bs authored and jdcasale committed
[SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events
Proposes a new set of APIs for ExecutorPlugin, providing callbacks invoked at the start and end of each task of a job. Not very opinionated on the shape of the API; it was kept as minimal as possible for now. The changes are described in detail on [SPARK-33088](https://issues.apache.org/jira/browse/SPARK-33088), but mostly boil down to:

1. This feature was considered when the ExecutorPlugin API was initially introduced in apache#21923, but never implemented.
2. The use case which **requires** this feature is to propagate tracing information from the driver to the executors, so that calls from the same job can all be traced.
   a. Tracing frameworks are usually set up in thread locals, so it is important for the setup to happen in the same thread that runs the tasks.
   b. Executors can be reused across multiple jobs, so it is not sufficient to set tracing information at executor startup time -- it needs to happen every time a task starts or ends.

Does this PR introduce any user-facing change? No; it introduces new hooks for plugin developers to use.

How was this patch tested? Unit tests in `PluginContainerSuite`.

Closes apache#29977 from fsamuel-bs/SPARK-33088.

Authored-by: Samuel Souza <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
1 parent 9ed1b5b commit d0048d7

File tree: 6 files changed, +163 −14 lines

core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java

Lines changed: 42 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 
 import java.util.Map;
 
+import org.apache.spark.TaskFailedReason;
 import org.apache.spark.annotation.DeveloperApi;
 
 /**
@@ -54,4 +55,45 @@ default void init(PluginContext ctx, Map<String, String> extraConf) {}
    */
   default void shutdown() {}
 
+  /**
+   * Perform any action before the task is run.
+   * <p>
+   * This method is invoked from the same thread the task will be executed.
+   * Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
+   * <p>
+   * Plugin authors should avoid expensive operations here, as this method will be called
+   * on every task, and doing something expensive can significantly slow down a job.
+   * It is not recommended for a user to call a remote service, for example.
+   * <p>
+   * Exceptions thrown from this method do not propagate - they're caught,
+   * logged, and suppressed. Therefore exceptions when executing this method won't
+   * make the job fail.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskStart() {}
+
+  /**
+   * Perform an action after tasks completes without exceptions.
+   * <p>
+   * As {@link #onTaskStart() onTaskStart} exceptions are suppressed, this method
+   * will still be invoked even if the corresponding {@link #onTaskStart} call for this
+   * task failed.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskSucceeded() {}
+
+  /**
+   * Perform an action after tasks completes with exceptions.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @param failureReason the exception thrown from the failed task.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskFailed(TaskFailedReason failureReason) {}
 }
```
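To make the intended usage concrete, here is a minimal sketch of an executor plugin that uses the new callbacks to set up and tear down tracing state in a thread local, the use case described in the commit message. The `TracingContext` object and the choice of trace identifier are hypothetical illustrations, not part of Spark or of any particular tracing library.

```scala
import java.util.{Map => JMap}

import org.apache.spark.{TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{ExecutorPlugin, PluginContext}

// Hypothetical thread-local holder standing in for a real tracing framework.
object TracingContext {
  private val current = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }
  def set(traceId: String): Unit = current.set(Some(traceId))
  def clear(): Unit = current.remove()
}

class TracingExecutorPlugin extends ExecutorPlugin {

  override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
    // One-time, per-executor setup, e.g. reading tracing configuration from extraConf.
  }

  // Invoked in the same thread that will run the task, after the TaskContext is installed,
  // so thread-local state set here is visible to the task's user code.
  override def onTaskStart(): Unit = {
    Option(TaskContext.get())
      .map(tc => s"stage-${tc.stageId()}-partition-${tc.partitionId()}")
      .foreach(TracingContext.set)
  }

  // onTaskStart exceptions are suppressed, so clean up in both end-of-task callbacks.
  override def onTaskSucceeded(): Unit = TracingContext.clear()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = TracingContext.clear()
}
```

Because `onTaskStart` runs in the task's own thread and after the `TaskContext` has been set (see the `Task.scala` change below), anything placed in a thread local here is visible to the user code of that task.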

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 19 additions & 10 deletions
```diff
@@ -228,7 +228,7 @@ private[spark] class Executor(
   private[executor] def numRunningTasks: Int = runningTasks.size()
 
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, plugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
   }
@@ -304,7 +304,8 @@
 
   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val plugins: Option[PluginContainer])
     extends Runnable {
 
     val taskId = taskDescription.taskId
@@ -455,7 +456,8 @@
           taskAttemptId = taskId,
           attemptNumber = taskDescription.attemptNumber,
           metricsSystem = env.metricsSystem,
-          resources = taskDescription.resources)
+          resources = taskDescription.resources,
+          plugins = plugins)
         threwException = false
         res
       } {
@@ -607,6 +609,7 @@
 
         executorSource.SUCCEEDED_TASKS.inc(1L)
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskSucceeded())
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
       } catch {
         case t: TaskKilledException =>
@@ -619,8 +622,9 @@
           // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
           // without requiring a copy.
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums, metricPeaks))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case _: InterruptedException | NonFatal(_) if
             task != null && task.reasonIfKilled.isDefined =>
@@ -632,8 +636,9 @@
 
           val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums, metricPeaks))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -650,11 +655,13 @@
               SafeArg.of("fetchFailedExceptionClass", fetchFailedCls))
           }
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
        case CausedBy(cDE: CommitDeniedException) =>
          val reason = cDE.toTaskCommitDeniedReason
          setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
        case t: Throwable if env.isStopped =>
@@ -683,21 +690,23 @@
            val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
            val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
 
-            val serializedTaskEndReason = {
+            val (taskFailureReason, serializedTaskFailureReason) = {
              try {
                val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
                  .withMetricPeaks(metricPeaks)
                ser.serialize(ef)
+                (ef, ser.serialize(ef))
              } catch {
                case _: NotSerializableException =>
                  // t is not serializable so just send the stacktrace
                  val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
                    .withMetricPeaks(metricPeaks)
-                  ser.serialize(ef)
+                  (ef, ser.serialize(ef))
              }
            }
            setTaskFinishedAndClearInterruptStatus()
-            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+            plugins.foreach(_.onTaskFailed(taskFailureReason))
+            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
          } else {
            safeLogInfo("Not reporting error to driver during JVM shutdown.")
          }
```
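The executor changes above thread the optional `PluginContainer` down to each `TaskRunner` and fire exactly one end-of-task callback per task, in addition to the existing status update. A simplified, self-contained model of that control flow (not the actual Spark code, and with the failure reason reduced to a string) looks like this:

```scala
// Simplified stand-in for PluginContainer; the real container fans out to all plugins.
trait TaskCallbacks {
  def onTaskStart(): Unit
  def onTaskSucceeded(): Unit
  def onTaskFailed(reason: String): Unit
}

// Sketch of the ordering the executor now follows around a single task.
def runOneTask[T](plugins: Option[TaskCallbacks])(body: => T): Either[String, T] = {
  plugins.foreach(_.onTaskStart())            // done in Task.run, in the task's thread
  try {
    val result = body                         // runTask(context)
    plugins.foreach(_.onTaskSucceeded())      // before statusUpdate(FINISHED)
    Right(result)
  } catch {
    case t: Throwable =>
      val reason = t.toString                 // ExceptionFailure / TaskKilled / ... in the real code
      plugins.foreach(_.onTaskFailed(reason)) // before statusUpdate(FAILED or KILLED)
      Left(reason)
  }
}
```

The model glosses over the distinct failure cases the real code separates (task kill, fetch failure, commit denial, generic exceptions) before choosing which `TaskFailedReason` to hand to `onTaskFailed`.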

core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala

Lines changed: 48 additions & 1 deletion
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
 import scala.collection.JavaConverters._
 import scala.util.{Either, Left, Right}
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {
 
   def shutdown(): Unit
   def registerMetrics(appId: String): Unit
+  def onTaskStart(): Unit
+  def onTaskSucceeded(): Unit
+  def onTaskFailed(failureReason: TaskFailedReason): Unit
 
 }
 
@@ -85,6 +88,17 @@ private class DriverPluginContainer(
     }
   }
 
+  override def onTaskStart(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
 }
 
 private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
       }
     }
   }
+
+  override def onTaskStart(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskStart()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskSucceeded()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskFailed(failureReason)
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
+      }
+    }
+  }
 }
 
 object PluginContainer {
```
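The executor-side container wraps every plugin callback in its own try/catch and only logs failures, so a misbehaving plugin can neither fail the task nor prevent other plugins from being notified. A tiny, self-contained illustration of that isolation pattern (plain Scala with hypothetical plugin names, not Spark code):

```scala
// Two hypothetical plugins: one throws from its callback, one behaves.
val plugins = Seq(
  "noisy" -> (() => throw new RuntimeException("boom")),
  "quiet" -> (() => println("quiet plugin notified"))
)

// Same shape as ExecutorPluginContainer.onTaskStart: isolate each plugin's failure.
plugins.foreach { case (name, callback) =>
  try {
    callback()
  } catch {
    case t: Throwable =>
      println(s"Exception while calling onTaskStart on plugin $name: $t")
  }
}
// Prints the exception for "noisy", then still prints "quiet plugin notified".
```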

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 1 deletion
```diff
@@ -23,6 +23,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
-      resources: Map[String, ResourceInformation]): T = {
+      resources: Map[String, ResourceInformation],
+      plugins: Option[PluginContainer]): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -123,6 +125,8 @@
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()
 
+    plugins.foreach(_.onTaskStart())
+
     try {
       runTask(context)
     } catch {
```

core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala

Lines changed: 47 additions & 0 deletions
```diff
@@ -129,6 +129,38 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with LocalSparkContext {
     assert(TestSparkPlugin.driverPlugin != null)
   }
 
+  test("SPARK-33088: executor tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    sc.parallelize(1 to 10, 2).count()
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
+  }
+
+  test("SPARK-33088: executor failed tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    try {
+      sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
+    } catch {
+      case t: Throwable => // ignore exception
+    }
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
+  }
+
   test("plugin initialization in non-local mode") {
     val path = Utils.createTempDir()
 
@@ -309,13 +341,28 @@ private class TestDriverPlugin extends DriverPlugin {
 
 private class TestExecutorPlugin extends ExecutorPlugin {
 
+  var numOnTaskStart: Int = 0
+  var numOnTaskSucceeded: Int = 0
+  var numOnTaskFailed: Int = 0
+
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
       override def getValue(): Int = 84
     })
     TestSparkPlugin.executorContext = ctx
   }
 
+  override def onTaskStart(): Unit = {
+    numOnTaskStart += 1
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    numOnTaskSucceeded += 1
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    numOnTaskFailed += 1
+  }
 }
 
 private object TestSparkPlugin {
```
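The tests above register the plugin through the `spark.plugins` configuration (the `PLUGINS` config constant) and drive a small local job with two partitions, so each callback fires twice. For an application, the wiring would look roughly like the following sketch; `TracingSparkPlugin` and `TracingExecutorPlugin` are hypothetical names reusing the earlier example, not classes introduced by this change.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}

// Entry point Spark instantiates on the driver and executors when listed in spark.plugins.
class TracingSparkPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null   // no driver-side hooks needed here
  override def executorPlugin(): ExecutorPlugin = new TracingExecutorPlugin
}

object TracingExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("tracing-example")
      .set("spark.plugins", classOf[TracingSparkPlugin].getName)

    val sc = new SparkContext(conf)
    sc.parallelize(1 to 10, 2).count()  // each of the 2 tasks triggers onTaskStart/onTaskSucceeded
    sc.stop()
  }
}
```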

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.completed)
   }
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.lastError.getMessage == "damn error")
   }
```

0 commit comments