
Commit 73ce892

Add ExecutorPlugin methods called on Task Start/End (#713)
* wip
* API with docs
* move from Executor constructor to runtime
* style
* license
* style
* address comments
1 parent: 6ce2020

4 files changed: +232, -7 lines

core/src/main/java/org/apache/spark/ExecutorPlugin.java

Lines changed: 34 additions & 0 deletions
@@ -54,4 +54,38 @@ default void init() {}
   * will wait for the plugin to terminate before continuing its own shutdown.</p>
   */
  default void shutdown() {}
+
+  /**
+   * Perform any action before the task is run.
+   *
+   * <p>This method is invoked from the same thread the task will be executed in.
+   * Task-specific information can be accessed via {@link TaskContext#get}.</p>
+   *
+   * <p>Users should avoid expensive operations here, as this method will be called
+   * on every task, and doing something expensive can significantly slow down a job.
+   * It is not recommended for a user to call a remote service, for example.</p>
+   *
+   * <p>Exceptions thrown from this method do not propagate: they are caught,
+   * logged, and suppressed, so an exception thrown while executing this method
+   * will not make the job fail.</p>
+   */
+  default void onTaskStart() {}
+
+  /**
+   * Perform an action after a task completes without exceptions.
+   *
+   * <p>Because {@link #onTaskStart() onTaskStart} exceptions are suppressed, this method
+   * will still be invoked even if the corresponding {@link #onTaskStart} call for this
+   * task failed.</p>
+   *
+   * <p>The same warnings as for {@link #onTaskStart() onTaskStart} apply here.</p>
+   */
+  default void onTaskSucceeded() {}
+
+  /**
+   * Perform an action after a task completes with an exception.
+   *
+   * <p>The same warnings as for {@link #onTaskStart() onTaskStart} apply here.</p>
+   */
+  default void onTaskFailed(Throwable throwable) {}
 }

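To make the new hooks concrete, here is a minimal sketch of a plugin built against this API. The package, class name, and counters are hypothetical, not part of this commit; registration via the spark.executor.plugins setting is the mechanism exercised by the test suite below.

package com.example.plugins;  // hypothetical package, not part of this commit

import java.util.concurrent.atomic.AtomicLong;

import org.apache.spark.ExecutorPlugin;
import org.apache.spark.TaskContext;

// Hypothetical plugin that counts task outcomes on an executor.
// All three hooks run on the task's own thread, so they must stay cheap.
public class TaskCountingPlugin implements ExecutorPlugin {
  private static final AtomicLong STARTED = new AtomicLong();
  private static final AtomicLong SUCCEEDED = new AtomicLong();
  private static final AtomicLong FAILED = new AtomicLong();

  @Override
  public void onTaskStart() {
    STARTED.incrementAndGet();
    // Task-specific information is available because we are on the task thread.
    TaskContext ctx = TaskContext.get();
    if (ctx != null) {
      long attemptId = ctx.taskAttemptId();  // e.g. tag a metric with this id
    }
  }

  @Override
  public void onTaskSucceeded() {
    SUCCEEDED.incrementAndGet();
  }

  @Override
  public void onTaskFailed(Throwable throwable) {
    FAILED.incrementAndGet();
  }
}

Such a plugin would be enabled with spark.executor.plugins=com.example.plugins.TaskCountingPlugin (comma-separated for multiple plugins), as the test suite's initializeSparkConf helper does below.
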
core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 5 additions & 3 deletions
@@ -216,7 +216,7 @@ private[spark] class Executor(
   private[executor] def numRunningTasks: Int = runningTasks.size()

   def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
-    val tr = new TaskRunner(context, taskDescription)
+    val tr = new TaskRunner(context, taskDescription, executorPlugins)
     runningTasks.put(taskDescription.taskId, tr)
     threadPool.execute(tr)
   }
@@ -292,7 +292,8 @@ private[spark] class Executor(

   class TaskRunner(
       execBackend: ExecutorBackend,
-      private val taskDescription: TaskDescription)
+      private val taskDescription: TaskDescription,
+      private val executorPlugins: Seq[ExecutorPlugin])
     extends Runnable {

     val taskId = taskDescription.taskId
@@ -435,7 +436,8 @@ private[spark] class Executor(
         val res = task.run(
           taskAttemptId = taskId,
           attemptNumber = taskDescription.attemptNumber,
-          metricsSystem = env.metricsSystem)
+          metricsSystem = env.metricsSystem,
+          executorPlugins = executorPlugins)
         threwException = false
         res
       } {

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 50 additions & 4 deletions
@@ -20,8 +20,11 @@ package org.apache.spark.scheduler
 import java.nio.ByteBuffer
 import java.util.Properties

+import com.palantir.logsafe.UnsafeArg
+
 import org.apache.spark._
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.internal.SafeLogging
 import org.apache.spark.internal.config.APP_CALLER_CONTEXT
 import org.apache.spark.memory.{MemoryMode, TaskMemoryManager}
 import org.apache.spark.metrics.MetricsSystem
@@ -63,7 +66,7 @@ private[spark] abstract class Task[T](
     val jobId: Option[Int] = None,
     val appId: Option[String] = None,
     val appAttemptId: Option[String] = None,
-    val isBarrier: Boolean = false) extends Serializable {
+    val isBarrier: Boolean = false) extends Serializable with SafeLogging {

   @transient lazy val metrics: TaskMetrics =
     SparkEnv.get.closureSerializer.newInstance().deserialize(ByteBuffer.wrap(serializedTaskMetrics))
@@ -72,13 +75,15 @@ private[spark] abstract class Task[T](
    * Called by [[org.apache.spark.executor.Executor]] to run this task.
    *
    * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
-   * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
+   * @param attemptNumber how many times this task has been attempted (0 for the first attempt).
+   * @param executorPlugins the plugins which will be notified of the run of this task.
    * @return the result of the task along with updates of Accumulators.
    */
   final def run(
       taskAttemptId: Long,
       attemptNumber: Int,
-      metricsSystem: MetricsSystem): T = {
+      metricsSystem: MetricsSystem,
+      executorPlugins: Seq[ExecutorPlugin] = Seq.empty): T = {
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether
     // the stage is barrier.
@@ -117,8 +122,12 @@ private[spark] abstract class Task[T](
       Option(taskAttemptId),
       Option(attemptNumber)).setCurrentContext()

+    sendTaskStartToPlugins(executorPlugins)
+
     try {
-      runTask(context)
+      val taskResult = runTask(context)
+      sendTaskSucceededToPlugins(executorPlugins)
+      taskResult
     } catch {
       case e: Throwable =>
         // Catch all errors; run task failure callbacks, and rethrow the exception.
@@ -129,6 +138,7 @@ private[spark] abstract class Task[T](
           e.addSuppressed(t)
         }
         context.markTaskCompleted(Some(e))
+        sendTaskFailedToPlugins(executorPlugins, e)
         throw e
     } finally {
       try {
@@ -159,6 +169,42 @@ private[spark] abstract class Task[T](
     }
   }

+  private def sendTaskStartToPlugins(executorPlugins: Seq[ExecutorPlugin]) {
+    executorPlugins.foreach { plugin =>
+      try {
+        plugin.onTaskStart()
+      } catch {
+        case e: Exception =>
+          safeLogWarning("Plugin onTaskStart failed", e,
+            UnsafeArg.of("pluginName", plugin.getClass().getCanonicalName()))
+      }
+    }
+  }
+
+  private def sendTaskSucceededToPlugins(executorPlugins: Seq[ExecutorPlugin]) {
+    executorPlugins.foreach { plugin =>
+      try {
+        plugin.onTaskSucceeded()
+      } catch {
+        case e: Exception =>
+          safeLogWarning("Plugin onTaskSucceeded failed", e,
+            UnsafeArg.of("pluginName", plugin.getClass().getCanonicalName()))
+      }
+    }
+  }
+
+  private def sendTaskFailedToPlugins(executorPlugins: Seq[ExecutorPlugin], error: Throwable) {
+    executorPlugins.foreach { plugin =>
+      try {
+        plugin.onTaskFailed(error)
+      } catch {
+        case e: Exception =>
+          safeLogWarning("Plugin onTaskFailed failed", e,
+            UnsafeArg.of("pluginName", plugin.getClass().getCanonicalName()))
+      }
+    }
+  }
+
   private var taskMemoryManager: TaskMemoryManager = _

   def setTaskMemoryManager(taskMemoryManager: TaskMemoryManager): Unit = {
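
The dispatch helpers above deliberately swallow plugin exceptions so a misbehaving plugin cannot fail a task, and every hook fires on the task's own thread. That same-thread guarantee means a plugin can pair a task's start with its end using a plain ThreadLocal, with no task-id bookkeeping. A hypothetical sketch (class, package, and output are illustrative only, not part of this commit):

package com.example.plugins;  // hypothetical package, not part of this commit

import org.apache.spark.ExecutorPlugin;

// Hypothetical latency-tracking plugin. Because onTaskStart and the
// completion hooks all run on the task's own thread, a ThreadLocal is
// enough to correlate the start and end of the same task.
public class TaskTimingPlugin implements ExecutorPlugin {
  private static final ThreadLocal<Long> START_NANOS = new ThreadLocal<>();

  @Override
  public void onTaskStart() {
    START_NANOS.set(System.nanoTime());
  }

  @Override
  public void onTaskSucceeded() {
    record("succeeded");
  }

  @Override
  public void onTaskFailed(Throwable throwable) {
    record("failed");
  }

  private void record(String outcome) {
    Long start = START_NANOS.get();
    START_NANOS.remove();  // executor threads are pooled and reused
    if (start != null) {   // null if onTaskStart never ran for this task
      long elapsedMs = (System.nanoTime() - start) / 1_000_000;
      // Keep this cheap: it runs on every task's critical path.
      System.out.println("task " + outcome + " after " + elapsedMs + " ms");
    }
  }
}

The null check matters because, as documented on onTaskSucceeded, the completion hooks still run even when the corresponding onTaskStart call failed.
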
core/src/test/java/org/apache/spark/ExecutorPluginTaskSuite.java (new file)

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark;

import com.google.common.collect.ImmutableList;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import static org.junit.Assert.*;

public class ExecutorPluginTaskSuite {
  private static final String EXECUTOR_PLUGIN_CONF_NAME = "spark.executor.plugins";
  private static final String taskWellBehavedPluginName = TestWellBehavedPlugin.class.getName();
  private static final String taskBadlyBehavedPluginName = TestBadlyBehavedPlugin.class.getName();

  // Static values modified by the testing plugins to ensure plugins are called correctly.
  public static int numOnTaskStart = 0;
  public static int numOnTaskSucceeded = 0;
  public static int numOnTaskFailed = 0;

  private JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = null;
    numOnTaskStart = 0;
    numOnTaskSucceeded = 0;
    numOnTaskFailed = 0;
  }

  @After
  public void tearDown() {
    if (sc != null) {
      sc.stop();
      sc = null;
    }
  }

  private SparkConf initializeSparkConf(String pluginNames) {
    return new SparkConf()
      .setMaster("local")
      .setAppName("test")
      .set(EXECUTOR_PLUGIN_CONF_NAME, pluginNames);
  }

  @Test
  public void testWellBehavedPlugin() {
    SparkConf conf = initializeSparkConf(taskWellBehavedPluginName);

    sc = new JavaSparkContext(conf);
    JavaRDD<Integer> rdd = sc.parallelize(ImmutableList.of(1, 2));
    rdd.filter(value -> value.equals(1)).collect();

    assertEquals(numOnTaskStart, 1);
    assertEquals(numOnTaskSucceeded, 1);
    assertEquals(numOnTaskFailed, 0);
  }

  @Test
  public void testBadlyBehavedPluginDoesNotAffectWellBehavedPlugin() {
    SparkConf conf = initializeSparkConf(
      taskWellBehavedPluginName + "," + taskBadlyBehavedPluginName);

    sc = new JavaSparkContext(conf);
    JavaRDD<Integer> rdd = sc.parallelize(ImmutableList.of(1, 2));
    rdd.filter(value -> value.equals(1)).collect();

    assertEquals(numOnTaskStart, 1);
    assertEquals(numOnTaskSucceeded, 2);
    assertEquals(numOnTaskFailed, 0);
  }

  @Test
  public void testTaskWhichFails() {
    SparkConf conf = initializeSparkConf(taskWellBehavedPluginName);

    sc = new JavaSparkContext(conf);
    JavaRDD<Integer> rdd = sc.parallelize(ImmutableList.of(1, 2));
    try {
      rdd.foreach(integer -> {
        throw new RuntimeException();
      });
    } catch (Exception e) {
      // ignore exception
    }

    assertEquals(numOnTaskStart, 1);
    assertEquals(numOnTaskSucceeded, 0);
    assertEquals(numOnTaskFailed, 1);
  }

  public static class TestWellBehavedPlugin implements ExecutorPlugin {
    @Override
    public void onTaskStart() {
      numOnTaskStart++;
    }

    @Override
    public void onTaskSucceeded() {
      numOnTaskSucceeded++;
    }

    @Override
    public void onTaskFailed(Throwable throwable) {
      numOnTaskFailed++;
    }
  }

  public static class TestBadlyBehavedPlugin implements ExecutorPlugin {
    @Override
    public void onTaskStart() {
      throw new RuntimeException();
    }

    @Override
    public void onTaskSucceeded() {
      numOnTaskSucceeded++;
    }

    @Override
    public void onTaskFailed(Throwable throwable) {
      numOnTaskFailed++;
    }
  }
}
