123 changes: 123 additions & 0 deletions core/src/main/scala/org/apache/spark/JobArtifactSet.scala
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import java.io.Serializable

/**
* Artifact set for a job.
* This class is used to store session-specific (i.e. `SparkSession`) resources/artifacts.
*
* When Spark Connect is used, this job-set points towards session-specific jars and class files.
* Note that Spark Connect is not a requirement for using this class.
*
* @param uuid An optional UUID for this session. If unset, a default session will be used.
* @param replClassDirUri An optional custom URI to point towards class files.
* @param jars Jars belonging to this session.
* @param files Files belonging to this session.
* @param archives Archives belonging to this session.
*/
class JobArtifactSet(
val uuid: Option[String],
val replClassDirUri: Option[String],
val jars: Map[String, Long],
val files: Map[String, Long],
val archives: Map[String, Long]) extends Serializable {
def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f)

override def hashCode(): Int = {
Seq(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq).hashCode()
Contributor: Objects.hash(...) is a bit simpler.
}
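// A minimal sketch of the reviewer's Objects.hash suggestion (assumes java.util.Objects is
// imported; not part of this PR). It would replace the Seq(...).hashCode() line above:
//   override def hashCode(): Int =
//     Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq)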

override def equals(obj: Any): Boolean = {
obj match {
case that: JobArtifactSet =>
this.getClass == that.getClass && this.uuid == that.uuid &&
this.replClassDirUri == that.replClassDirUri && this.jars.toSeq == that.jars.toSeq &&
this.files.toSeq == that.files.toSeq && this.archives.toSeq == that.archives.toSeq
case _ => false
}
}

}

object JobArtifactSet {

private[this] val current = new ThreadLocal[Option[JobArtifactSet]] {
override def initialValue(): Option[JobArtifactSet] = None
}

/**
* When Spark Connect isn't used, we default back to the shared resources.
* @param sc The active [[SparkContext]]
* @return A [[JobArtifactSet]] containing a copy of the jars/files/archives from the underlying
* [[SparkContext]] `sc`.
*/
def apply(sc: SparkContext): JobArtifactSet = {
new JobArtifactSet(
uuid = None,
replClassDirUri = sc.conf.getOption("spark.repl.class.uri"),
jars = sc.addedJars.toMap,
files = sc.addedFiles.toMap,
archives = sc.addedArchives.toMap)
}

/**
* Empty artifact set for use in tests.
*/
private[spark] def apply(): JobArtifactSet = {
Contributor: Just create a single empty one?
new JobArtifactSet(
None,
None,
Map.empty,
Map.empty,
Map.empty)
}
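// A sketch of the reviewer's "single empty one" suggestion (hypothetical name, not part of
// this PR): share one immutable instance instead of building a fresh empty set per call.
//   private[spark] lazy val emptyJobArtifactSet: JobArtifactSet =
//     new JobArtifactSet(None, None, Map.empty, Map.empty, Map.empty)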

/**
* Used for testing. Returns the artifacts from the active [[SparkContext]] if one exists,
* or an empty set otherwise.
*/
private[spark] def defaultArtifactSet(): JobArtifactSet = {
SparkContext.getActive.map(sc => JobArtifactSet(sc)).getOrElse(JobArtifactSet())
}

/**
* Execute a block of code with the given [[JobArtifactSet]] as the thread-local active set,
* restoring the previous value afterwards.
* @param active The artifact set to make active while the block runs.
* @param block The block of code to execute.
* @tparam T The return type of the block.
*/
def withActive[T](active: JobArtifactSet)(block: => T): T = {
val old = current.get()
current.set(Option(active))
try block finally {
current.set(old)
}
}

/**
* Optionally returns the active [[JobArtifactSet]].
*/
def active: Option[JobArtifactSet] = current.get()

/**
* Returns the active [[JobArtifactSet]], or creates the default set from the given
* [[SparkContext]] if none is active.
* @param sc The [[SparkContext]] used to build the default set.
*/
def getActiveOrDefault(sc: SparkContext): JobArtifactSet = active.getOrElse(JobArtifactSet(sc))
}
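
A minimal usage sketch of the thread-local API above, assuming an already-running SparkContext named `sc` (illustrative only, not part of the PR):

val artifacts = JobArtifactSet.getActiveOrDefault(sc) // falls back to sc's jars/files/archives
artifacts.withActive {
  // While the block runs, JobArtifactSet.active on this thread returns the set we activated,
  // so job-submission code running here can attach it to the tasks it creates.
  assert(JobArtifactSet.active.contains(artifacts))
  sc.parallelize(1 to 10).count()
}
// After the block, the previous thread-local value (None if nothing was active) is restored.
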
120 changes: 79 additions & 41 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -31,10 +31,11 @@ import javax.ws.rs.core.UriBuilder

import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray}
import scala.collection.mutable.{ArrayBuffer, HashMap, WrappedArray}
import scala.concurrent.duration._
import scala.util.control.NonFatal

import com.google.common.cache.CacheBuilder
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.slf4j.MDC

@@ -53,6 +54,14 @@ import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._

private[spark] class IsolatedSessionState(
val sessionUUID: String,
val urlClassLoader: MutableURLClassLoader,
var replClassLoader: ClassLoader,
val currentFiles: HashMap[String, Long],
val currentJars: HashMap[String, Long],
val currentArchives: HashMap[String, Long])

/**
* Spark executor, backed by a threadpool to run tasks.
*
@@ -76,11 +85,6 @@ private[spark] class Executor(
val stopHookReference = ShutdownHookManager.addShutdownHook(
() => stop()
)
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
private val currentArchives: HashMap[String, Long] = new HashMap[String, Long]()

private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))

@@ -160,16 +164,34 @@

private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)

// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
private val systemLoader = Utils.getContextOrSparkClassLoader

private def newSessionState(
sessionUUID: String,
classUri: Option[String]): IsolatedSessionState = {
val currentFiles = new HashMap[String, Long]
val currentJars = new HashMap[String, Long]
val currentArchives = new HashMap[String, Long]
val urlClassLoader = createClassLoader(currentJars)
val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader, classUri)
new IsolatedSessionState(
sessionUUID, urlClassLoader, replClassLoader, currentFiles, currentJars, currentArchives)
}

// Classloader isolation
// The default isolation group
val defaultSessionState = newSessionState("default", None)

val isolatedSessionCache = CacheBuilder.newBuilder()
Member: There is another problem here. If you cache the Spark session and, say, it's evicted, a new session state will be created with empty file lists. In this case, the Executor will try to download the files again and overwrite all of them. This behaviour is disallowed by default in SparkContext.
Member: spark.files.overwrite is false by default, so the tasks will actually fail.
.maximumSize(100)
.expireAfterAccess(5, TimeUnit.MINUTES)
.build[String, IsolatedSessionState]
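// Sketch of the failure mode the reviewers flag above (descriptive only, not part of the PR):
// once an entry expires, the next isolatedSessionCache.get(uuid, () => newSessionState(...))
// call re-runs the loader, and the rebuilt IsolatedSessionState starts with empty
// currentFiles/currentJars/currentArchives maps. updateDependencies() then treats every
// artifact as missing and re-fetches it, and with spark.files.overwrite=false (the default)
// the tasks fail rather than silently overwriting the existing files.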

// Set the classloader for serializer
env.serializer.setDefaultClassLoader(replClassLoader)
env.serializer.setDefaultClassLoader(defaultSessionState.replClassLoader)
// SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads
// for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
env.serializerManager.setDefaultClassLoader(replClassLoader)
env.serializerManager.setDefaultClassLoader(defaultSessionState.replClassLoader)

// Max size of direct result. If task result is bigger than this, we use the block manager
// to send the result back. This is guaranteed to be smaller than array bytes limit (2GB)
@@ -273,17 +295,18 @@
private val Seq(initialUserJars, initialUserFiles, initialUserArchives) =
Seq("jar", "file", "archive").map { key =>
conf.getOption(s"spark.app.initial.$key.urls").map { urls =>
Map(urls.split(",").map(url => (url, appStartTime)): _*)
}.getOrElse(Map.empty)
immutable.Map(urls.split(",").map(url => (url, appStartTime)): _*)
}.getOrElse(immutable.Map.empty)
}
updateDependencies(initialUserFiles, initialUserJars, initialUserArchives)
updateDependencies(initialUserFiles, initialUserJars, initialUserArchives, defaultSessionState)

// Plugins need to load using a class loader that includes the executor's user classpath.
// Plugins also need to be initialized after the heartbeater started
// to avoid blocking to send heartbeat (see SPARK-32175).
private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
PluginContainer(env, resources.asJava)
}
private val plugins: Option[PluginContainer] =
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
PluginContainer(env, resources.asJava)
}

metricsPoller.start()

@@ -381,9 +404,9 @@
if (killMarkCleanupService != null) {
killMarkCleanupService.shutdown()
}
if (replClassLoader != null && plugins != null) {
if (defaultSessionState != null && plugins != null) {
// Notify plugins that executor is shutting down so they can terminate cleanly
Utils.withContextClassLoader(replClassLoader) {
Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
plugins.foreach(_.shutdown())
}
}
@@ -485,6 +508,16 @@
}

override def run(): Unit = {

// Classloader isolation
val isolatedSessionUUID: Option[String] = taskDescription.artifacts.uuid
val isolatedSession = isolatedSessionUUID match {
case Some(uuid) => isolatedSessionCache.get(
uuid,
() => newSessionState(uuid, taskDescription.artifacts.replClassDirUri))
case _ => defaultSessionState
}

setMDCForTask(taskName, mdcProperties)
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
@@ -494,7 +527,7 @@
val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
Thread.currentThread.setContextClassLoader(replClassLoader)
Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -509,7 +542,10 @@
Executor.taskDeserializationProps.set(taskDescription.properties)

updateDependencies(
taskDescription.addedFiles, taskDescription.addedJars, taskDescription.addedArchives)
taskDescription.artifacts.files,
taskDescription.artifacts.jars,
taskDescription.artifacts.archives,
isolatedSession)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task.localProperties = taskDescription.properties
@@ -961,15 +997,13 @@
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
private def createClassLoader(currentJars: HashMap[String, Long]): MutableURLClassLoader = {
// Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis()
userClassPath.foreach { url =>
currentJars(url.getPath().split("/").last) = now
}

val currentLoader = Utils.getContextOrSparkClassLoader

// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
@@ -978,18 +1012,20 @@
logInfo(s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " +
urls.mkString("'", ",", "'"))
if (userClassPathFirst) {
new ChildFirstURLClassLoader(urls, currentLoader)
new ChildFirstURLClassLoader(urls, systemLoader)
} else {
new MutableURLClassLoader(urls, currentLoader)
new MutableURLClassLoader(urls, systemLoader)
}
}

/**
* If the REPL is in use, add another ClassLoader that will read
* new classes defined by the REPL as the user types code
*/
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = conf.get("spark.repl.class.uri", null)
private def addReplClassLoaderIfNeeded(
parent: ClassLoader,
sessionClassUri: Option[String]): ClassLoader = {
val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null))
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst)
@@ -1004,9 +1040,10 @@
* Visible for testing.
*/
private[executor] def updateDependencies(
newFiles: Map[String, Long],
newJars: Map[String, Long],
newArchives: Map[String, Long],
newFiles: immutable.Map[String, Long],
newJars: immutable.Map[String, Long],
newArchives: immutable.Map[String, Long],
state: IsolatedSessionState,
testStartLatch: Option[CountDownLatch] = None,
testEndLatch: Option[CountDownLatch] = None): Unit = {
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
@@ -1015,14 +1052,15 @@
// For testing, so we can simulate a slow file download:
testStartLatch.foreach(_.countDown())
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
for ((name, timestamp) <- newFiles if state.currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
state.currentFiles(name) = timestamp
}
for ((name, timestamp) <- newArchives if currentArchives.getOrElse(name, -1L) < timestamp) {
for ((name, timestamp) <- newArchives if
state.currentArchives.getOrElse(name, -1L) < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
val sourceURI = new URI(name)
val uriToDownload = UriBuilder.fromUri(sourceURI).fragment(null).build()
@@ -1035,24 +1073,24 @@
s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
Utils.deleteRecursively(dest)
Utils.unpack(source, dest)
currentArchives(name) = timestamp
state.currentArchives(name) = timestamp
}
for ((name, timestamp) <- newJars) {
val localName = new URI(name).getPath.split("/").last
val currentTimeStamp = currentJars.get(name)
.orElse(currentJars.get(localName))
val currentTimeStamp = state.currentJars.get(name)
.orElse(state.currentJars.get(localName))
.getOrElse(-1L)
if (currentTimeStamp < timestamp) {
logInfo(s"Fetching $name with timestamp $timestamp")
// Fetch file with useCache mode, close cache for local mode.
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
state.currentJars(name) = timestamp
// Add it to our class loader
val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
if (!urlClassLoader.getURLs().contains(url)) {
if (!state.urlClassLoader.getURLs().contains(url)) {
logInfo(s"Adding $url to class loader")
urlClassLoader.addURL(url)
state.urlClassLoader.addURL(url)
}
}
}