Commit 5d8d086

fsamuel-bs authored and rshkv committed
[SPARK-33088][CORE] Enhance ExecutorPlugin API to include callbacks on task start and end events
Proposing a new set of APIs for ExecutorPlugins, to provide callbacks invoked at the start and end of each task of a job. Not very opinionated on the shape of the API; tried to keep it as minimal as possible for now. Changes are described in detail in [SPARK-33088](https://issues.apache.org/jira/browse/SPARK-33088), but mostly this boils down to:

1. This feature was considered when the ExecutorPlugin API was initially introduced in apache#21923, but never implemented.
2. The use-case which **requires** this feature is to propagate tracing information from the driver to the executor, such that calls from the same job can all be traced.
   a. Tracing frameworks are usually set up in thread locals, therefore it's important for the setup to happen in the same thread which runs the tasks.
   b. Executors can be shared by multiple jobs, therefore it's not sufficient to set tracing information at executor startup time -- it needs to happen every time a task starts or ends.

Does this PR introduce any user-facing change? No. This PR introduces new features for future developers to use.

How was this patch tested? Unit tests in `PluginContainerSuite`.

Closes apache#29977 from fsamuel-bs/SPARK-33088.

Authored-by: Samuel Souza <[email protected]>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
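For illustration only (not part of this commit), here is a minimal sketch of the tracing use-case described above, assuming the callbacks land as proposed. The `TracingSparkPlugin`/`TracingExecutorPlugin` names and the thread-local "span" string are hypothetical stand-ins for a real tracing framework; only the `SparkPlugin`/`ExecutorPlugin` interfaces and the `onTaskStart`/`onTaskSucceeded`/`onTaskFailed` callbacks come from this change:

```scala
import java.util.{Map => JMap}

import org.apache.spark.{TaskContext, TaskFailedReason}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}

// Hypothetical plugin: the names and the "span" payload are illustrative only.
class TracingSparkPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null  // no driver-side behavior needed here
  override def executorPlugin(): ExecutorPlugin = new TracingExecutorPlugin
}

class TracingExecutorPlugin extends ExecutorPlugin {
  // Stand-in for a tracing framework's thread-local context.
  private val currentSpan = new ThreadLocal[String]()

  override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
    // One-time, per-executor setup (e.g. reading a tracing endpoint from extraConf).
  }

  // Invoked in the task's own thread, so the thread-local is visible to task code.
  override def onTaskStart(): Unit = {
    val tc = TaskContext.get()
    currentSpan.set(s"stage=${tc.stageId()} partition=${tc.partitionId()} attempt=${tc.taskAttemptId()}")
  }

  override def onTaskSucceeded(): Unit = currentSpan.remove()

  override def onTaskFailed(failureReason: TaskFailedReason): Unit = currentSpan.remove()
}
```

Because `onTaskStart` runs in the task's thread after the `TaskContext` is set, thread-local state installed there is visible to whatever code the task executes, and `onTaskSucceeded`/`onTaskFailed` give a place to clean it up.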
1 parent 5123c5a commit 5d8d086

File tree

6 files changed (+165, -15 lines)

core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java

Lines changed: 42 additions & 0 deletions
@@ -19,6 +19,7 @@
 
 import java.util.Map;
 
+import org.apache.spark.TaskFailedReason;
 import org.apache.spark.annotation.DeveloperApi;
 
 /**
@@ -54,4 +55,45 @@ default void init(PluginContext ctx, Map<String, String> extraConf) {}
   */
  default void shutdown() {}
 
+  /**
+   * Perform any action before the task is run.
+   * <p>
+   * This method is invoked from the same thread the task will be executed.
+   * Task-specific information can be accessed via {@link org.apache.spark.TaskContext#get}.
+   * <p>
+   * Plugin authors should avoid expensive operations here, as this method will be called
+   * on every task, and doing something expensive can significantly slow down a job.
+   * It is not recommended for a user to call a remote service, for example.
+   * <p>
+   * Exceptions thrown from this method do not propagate - they're caught,
+   * logged, and suppressed. Therefore exceptions when executing this method won't
+   * make the job fail.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskStart() {}
+
+  /**
+   * Perform an action after tasks completes without exceptions.
+   * <p>
+   * As {@link #onTaskStart() onTaskStart} exceptions are suppressed, this method
+   * will still be invoked even if the corresponding {@link #onTaskStart} call for this
+   * task failed.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskSucceeded() {}
+
+  /**
+   * Perform an action after tasks completes with exceptions.
+   * <p>
+   * Same warnings of {@link #onTaskStart() onTaskStart} apply here.
+   *
+   * @param failureReason the exception thrown from the failed task.
+   *
+   * @since 3.1.0
+   */
+  default void onTaskFailed(TaskFailedReason failureReason) {}
 }
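As a usage note rather than part of this diff: an `ExecutorPlugin` like the sketch above is loaded through the existing `spark.plugins` configuration, after which the executor-side plugin container invokes these callbacks around every task. A minimal, hypothetical configuration sketch (the class name is illustrative):

```scala
import org.apache.spark.SparkConf

// Hypothetical plugin class; any SparkPlugin implementation on the executor classpath
// can be registered this way so its ExecutorPlugin task callbacks are invoked.
val conf = new SparkConf()
  .setAppName("tracing-example")
  .set("spark.plugins", "com.example.TracingSparkPlugin")
```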

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 19 additions & 11 deletions
@@ -225,7 +225,7 @@ private[spark] class Executor(
   private[executor] def numRunningTasks: Int = runningTasks.size()
 
   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, plugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
   }
@@ -301,7 +301,8 @@
 
   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val plugins: Option[PluginContainer])
     extends Runnable {
 
     val taskId = taskDescription.taskId
@@ -443,7 +444,8 @@
           taskAttemptId = taskId,
           attemptNumber = taskDescription.attemptNumber,
           metricsSystem = env.metricsSystem,
-          resources = taskDescription.resources)
+          resources = taskDescription.resources,
+          plugins = plugins)
         threwException = false
         res
       } {
@@ -579,6 +581,7 @@
 
         executorSource.SUCCEEDED_TASKS.inc(1L)
         setTaskFinishedAndClearInterruptStatus()
+        plugins.foreach(_.onTaskSucceeded())
         execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
       } catch {
         case t: TaskKilledException =>
@@ -588,8 +591,9 @@
           // Here and below, put task metric peaks in a WrappedArray to expose them as a Seq
           // without requiring a copy.
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums, metricPeaks))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(t.reason, accUpdates, accums, metricPeaks)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case _: InterruptedException | NonFatal(_) if
             task != null && task.reasonIfKilled.isDefined =>
@@ -598,8 +602,9 @@
 
           val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
           val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
-          val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums, metricPeaks))
-          execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)
+          val reason = TaskKilled(killReason, accUpdates, accums, metricPeaks)
+          plugins.foreach(_.onTaskFailed(reason))
+          execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
           val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -613,11 +618,13 @@
               s"other exception: $t")
           }
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
          execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
 
         case CausedBy(cDE: CommitDeniedException) =>
           val reason = cDE.toTaskCommitDeniedReason
           setTaskFinishedAndClearInterruptStatus()
+          plugins.foreach(_.onTaskFailed(reason))
           execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(reason))
 
         case t: Throwable if env.isStopped =>
@@ -640,21 +647,22 @@
             val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTimeNs)
             val metricPeaks = WrappedArray.make(metricsPoller.getTaskMetricPeaks(taskId))
 
-            val serializedTaskEndReason = {
+            val (taskFailureReason, serializedTaskFailureReason) = {
              try {
                 val ef = new ExceptionFailure(t, accUpdates).withAccums(accums)
                   .withMetricPeaks(metricPeaks)
-                ser.serialize(ef)
+                (ef, ser.serialize(ef))
               } catch {
                 case _: NotSerializableException =>
                   // t is not serializable so just send the stacktrace
                   val ef = new ExceptionFailure(t, accUpdates, false).withAccums(accums)
                     .withMetricPeaks(metricPeaks)
-                  ser.serialize(ef)
+                  (ef, ser.serialize(ef))
               }
             }
             setTaskFinishedAndClearInterruptStatus()
-            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
+            plugins.foreach(_.onTaskFailed(taskFailureReason))
+            execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskFailureReason)
           } else {
             logInfo("Not reporting error to driver during JVM shutdown.")
           }

core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala

Lines changed: 48 additions & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.internal.plugin
 import scala.collection.JavaConverters._
 import scala.util.{Either, Left, Right}
 
-import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.{SparkContext, SparkEnv, TaskFailedReason}
 import org.apache.spark.api.plugin._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
@@ -31,6 +31,9 @@ sealed abstract class PluginContainer {
 
   def shutdown(): Unit
   def registerMetrics(appId: String): Unit
+  def onTaskStart(): Unit
+  def onTaskSucceeded(): Unit
+  def onTaskFailed(failureReason: TaskFailedReason): Unit
 
 }
 
@@ -85,6 +88,17 @@ private class DriverPluginContainer(
     }
   }
 
+  override def onTaskStart(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    throw new IllegalStateException("Should not be called for the driver container.")
+  }
 }
 
 private class ExecutorPluginContainer(
@@ -134,6 +148,39 @@ private class ExecutorPluginContainer(
       }
     }
   }
+
+  override def onTaskStart(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskStart()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskStart on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskSucceeded()
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskSucceeded on plugin $name.", t)
+      }
+    }
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    executorPlugins.foreach { case (name, plugin) =>
+      try {
+        plugin.onTaskFailed(failureReason)
+      } catch {
+        case t: Throwable =>
+          logInfo(s"Exception while calling onTaskFailed on plugin $name.", t)
+      }
+    }
+  }
 }
 
 object PluginContainer {

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 5 additions & 1 deletion
@@ -23,6 +23,7 @@ import java.util.Properties
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
+import org.apache.spark.internal.plugin.PluginContainer
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
 import org.apache.spark.rdd.InputFileBlockHolder
@@ -82,7 +83,8 @@ private[spark] abstract class Task[T](
       taskAttemptId: Long,
       attemptNumber: Int,
       metricsSystem: MetricsSystem,
-      resources: Map[String, ResourceInformation]): T = {
+      resources: Map[String, ResourceInformation],
+      plugins: Option[PluginContainer]): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -123,6 +125,8 @@
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()
 
+    plugins.foreach(_.onTaskStart())
+
     try {
       runTask(context)
     } catch {

core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala

Lines changed: 49 additions & 0 deletions
@@ -129,6 +129,40 @@ class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with Lo
     assert(TestSparkPlugin.driverPlugin != null)
   }
 
+  test("SPARK-33088: executor tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    sc.parallelize(1 to 10, 2).count()
+
+    assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 2)
+    assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 0)
+  }
+
+  test("SPARK-33088: executor failed tasks trigger plugin calls") {
+    val conf = new SparkConf()
+      .setAppName(getClass().getName())
+      .set(SparkLauncher.SPARK_MASTER, "local[1]")
+      .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName()))
+
+    sc = new SparkContext(conf)
+    try {
+      sc.parallelize(1 to 10, 2).foreach(i => throw new RuntimeException)
+    } catch {
+      case t: Throwable => // ignore exception
+    }
+
+    eventually(timeout(10.seconds), interval(100.millis)) {
+      assert(TestSparkPlugin.executorPlugin.numOnTaskStart == 2)
+      assert(TestSparkPlugin.executorPlugin.numOnTaskSucceeded == 0)
+      assert(TestSparkPlugin.executorPlugin.numOnTaskFailed == 2)
+    }
+  }
+
   test("plugin initialization in non-local mode") {
     val path = Utils.createTempDir()
 
@@ -309,13 +343,28 @@ private class TestDriverPlugin extends DriverPlugin {
 
 private class TestExecutorPlugin extends ExecutorPlugin {
 
+  var numOnTaskStart: Int = 0
+  var numOnTaskSucceeded: Int = 0
+  var numOnTaskFailed: Int = 0
+
   override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
     ctx.metricRegistry().register("executorMetric", new Gauge[Int] {
       override def getValue(): Int = 84
     })
     TestSparkPlugin.executorContext = ctx
   }
 
+  override def onTaskStart(): Unit = {
+    numOnTaskStart += 1
+  }
+
+  override def onTaskSucceeded(): Unit = {
+    numOnTaskSucceeded += 1
+  }
+
+  override def onTaskFailed(failureReason: TaskFailedReason): Unit = {
+    numOnTaskFailed += 1
+  }
 }
 
 private object TestSparkPlugin {

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.completed)
   }
@@ -92,7 +92,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
       0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties,
       closureSerializer.serialize(TaskMetrics.registered).array())
     intercept[RuntimeException] {
-      task.run(0, 0, null, null)
+      task.run(0, 0, null, null, Option.empty)
     }
     assert(TaskContextSuite.lastError.getMessage == "damn error")
   }