core/src/main/scala/org/apache/spark/TestUtils.scala (15 changes: 12 additions & 3 deletions)
@@ -179,11 +179,20 @@ private[spark] object TestUtils {
destDir: File,
toStringValue: String = "",
baseClass: String = null,
classpathUrls: Seq[URL] = Seq.empty): File = {
classpathUrls: Seq[URL] = Seq.empty,
implementsClasses: Seq[String] = Seq.empty,
extraCodeBody: String = ""): File = {
val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("")
val implementsText =
"implements " + (implementsClasses :+ "java.io.Serializable").mkString(", ")
val sourceFile = new JavaSourceFromString(className,
"public class " + className + extendsText + " implements java.io.Serializable {" +
" @Override public String toString() { return \"" + toStringValue + "\"; }}")
s"""
|public class $className $extendsText $implementsText {
| @Override public String toString() { return "$toStringValue"; }
|
| $extraCodeBody
|}
""".stripMargin)
createCompiledClass(className, destDir, sourceFile, classpathUrls)
}

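For orientation, here is a hypothetical caller of the extended helper (not part of this patch). The new implementsClasses and extraCodeBody parameters let a test generate a class that implements extra interfaces and carries additional method bodies on top of the existing toString override. The class name, output directory, and Runnable body below are made up for illustration, and the object is assumed to live under org.apache.spark because TestUtils is private[spark].

// A minimal usage sketch; MyRunnable, /tmp/generated-classes, and the run() body are hypothetical.
package org.apache.spark

import java.io.File

object CreateCompiledClassExample {
  def main(args: Array[String]): Unit = {
    val outDir = new File("/tmp/generated-classes")   // hypothetical output directory
    val compiled: File = TestUtils.createCompiledClass(
      "MyRunnable",
      outDir,
      toStringValue = "my-runnable",
      baseClass = null,
      classpathUrls = Seq.empty,
      implementsClasses = Seq("java.lang.Runnable"),
      extraCodeBody = "@Override public void run() { System.out.println(\"hello\"); }")
    // The helper compiles roughly this generated source and returns the .class file:
    //   public class MyRunnable implements java.lang.Runnable, java.io.Serializable {
    //     @Override public String toString() { return "my-runnable"; }
    //     @Override public void run() { System.out.println("hello"); }
    //   }
    println(compiled.getAbsolutePath)
  }
}
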
core/src/main/scala/org/apache/spark/executor/Executor.scala (12 changes: 7 additions & 5 deletions)
@@ -154,11 +154,6 @@ private[spark] class Executor(
// for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
env.serializerManager.setDefaultClassLoader(replClassLoader)

// Plugins need to load using a class loader that includes the executor's user classpath
private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
PluginContainer(env, resources.asJava)
}

// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back.
private val maxDirectResultSize = Math.min(
@@ -225,6 +220,13 @@

heartbeater.start()

// Plugins need to load using a class loader that includes the executor's user classpath.
// Plugins also need to be initialized after the heartbeater has started
// to avoid blocking heartbeat sends (see SPARK-32175).
private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
PluginContainer(env, resources.asJava)
}

metricsPoller.start()

private[executor] def numRunningTasks: Int = runningTasks.size()
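The move above is the heart of SPARK-32175: plugin construction can block (a plugin's init may be slow), and while the Executor constructor is stuck there no heartbeats are sent, so the driver can declare the executor dead before it finishes starting. Starting the heartbeater first keeps heartbeats flowing during a slow plugin init. The standalone sketch below (not Spark code; names and the 1s/8s timings are illustrative) shows the same ordering idea.

// A self-contained sketch of the start-order fix, assuming nothing about Spark internals.
import java.util.concurrent.{Executors, TimeUnit}

object StartOrderSketch {
  def main(args: Array[String]): Unit = {
    // 1. Start the periodic "heartbeater" first, on its own scheduler thread.
    val heartbeater = Executors.newSingleThreadScheduledExecutor()
    heartbeater.scheduleAtFixedRate(
      new Runnable { def run(): Unit = println(s"heartbeat at ${System.currentTimeMillis()}") },
      0L, 1L, TimeUnit.SECONDS)

    // 2. Only then run the slow, blocking initialization (stand-in for a plugin's init).
    //    Heartbeats keep printing while this sleeps; with steps 1 and 2 swapped,
    //    nothing would be sent for these 8 seconds.
    Thread.sleep(8000L)

    heartbeater.shutdown()
  }
}
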
core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala (72 changes: 70 additions & 2 deletions)
@@ -17,7 +17,7 @@

package org.apache.spark.executor

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.io.{Externalizable, File, ObjectInput, ObjectOutput}
import java.lang.Thread.UncaughtExceptionHandler
import java.nio.ByteBuffer
import java.util.Properties
@@ -41,6 +41,7 @@ import org.scalatestplus.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{SimpleApplicationTest, SparkSubmitSuite}
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.UI._
import org.apache.spark.memory.TestMemoryManager
@@ -52,7 +53,7 @@ import org.apache.spark.scheduler.{DirectTaskResult, FakeTask, ResultTask, Task,
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManager, BlockManagerId}
import org.apache.spark.util.{LongAccumulator, UninterruptibleThread}
import org.apache.spark.util.{LongAccumulator, UninterruptibleThread, Utils}

class ExecutorSuite extends SparkFunSuite
with LocalSparkContext with MockitoSugar with Eventually with PrivateMethodTester {
@@ -402,6 +403,73 @@ class ExecutorSuite extends SparkFunSuite
assert(taskMetrics.getMetricValue("JVMHeapMemory") > 0)
}

test("SPARK-32175: Plugin initialization should start after heartbeater started") {
withTempDir { tempDir =>
val sparkPluginCodeBody =
"""
|@Override
|public org.apache.spark.api.plugin.ExecutorPlugin executorPlugin() {
| return new TestExecutorPlugin();
|}
|
|@Override
|public org.apache.spark.api.plugin.DriverPlugin driverPlugin() { return null; }
""".stripMargin
val executorPluginBody =
"""
|@Override
|public void init(
| org.apache.spark.api.plugin.PluginContext ctx,
| java.util.Map<String, String> extraConf) {
| try {
| Thread.sleep(8 * 1000);
| } catch (InterruptedException e) {
| throw new RuntimeException(e);
| }
|}
""".stripMargin

val compiledExecutorPlugin = TestUtils.createCompiledClass(
"TestExecutorPlugin",
tempDir,
"",
null,
Seq.empty,
Seq("org.apache.spark.api.plugin.ExecutorPlugin"),
executorPluginBody)

val thisClassPath =
sys.props("java.class.path").split(File.pathSeparator).map(p => new File(p).toURI.toURL)
val compiledSparkPlugin = TestUtils.createCompiledClass(
"TestSparkPlugin",
tempDir,
"",
null,
Seq(tempDir.toURI.toURL) ++ thisClassPath,
Seq("org.apache.spark.api.plugin.SparkPlugin"),
sparkPluginCodeBody)

val jarUrl = TestUtils.createJar(
Seq(compiledSparkPlugin, compiledExecutorPlugin),
new File(tempDir, "testPlugin.jar"))

val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val args = Seq(
"--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
"--name", "testApp",
"--master", "local-cluster[1,1,1024]",
"--conf", "spark.plugins=TestSparkPlugin",
"--conf", "spark.storage.blockManagerSlaveTimeoutMs=" + 5 * 1000,
"--conf", "spark.network.timeoutInterval=" + 1000,
"--conf", "spark.executor.heartbeatInterval=" + 1000,
"--conf", "spark.executor.extraClassPath=" + jarUrl.toString,
"--conf", "spark.driver.extraClassPath=" + jarUrl.toString,
"--conf", "spark.ui.enabled=false",
unusedJar.toString)
SparkSubmitSuite.runSparkSubmit(args, timeout = 30.seconds)
}
}

private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
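For readers tracing the test above: TestUtils stitches the two code bodies into full Java sources, compiles them into tempDir, jars them, and ships the jar to the driver and executor through spark.driver.extraClassPath and spark.executor.extraClassPath, like a real user-provided plugin. A hand-written Scala equivalent of the generated TestSparkPlugin/TestExecutorPlugin pair is sketched below; the name SlowInitSparkPlugin is illustrative, and the real test intentionally uses runtime-compiled Java classes instead of a class like this.

// Sketch only: Scala equivalent of the generated plugin classes in the test above.
import java.util.{Map => JMap}

import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext, SparkPlugin}

class SlowInitSparkPlugin extends SparkPlugin {
  override def driverPlugin(): DriverPlugin = null

  override def executorPlugin(): ExecutorPlugin = new ExecutorPlugin {
    override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = {
      // Sleep longer than the 5s block manager timeout the test configures,
      // simulating a plugin whose init outlives the heartbeat deadline.
      Thread.sleep(8 * 1000)
    }
  }
}

With heartbeats every second, a 5 second block manager timeout, and an 8 second plugin init, the submitted application can only finish within the 30 second runSparkSubmit timeout if heartbeats are already flowing while the plugin initializes, which is exactly the ordering the Executor change above enforces.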