[SPARK-44078][CONNECT][CORE] Add support for classloader/resource isolation #41625
Changes from all commits
JobArtifactSet.scala (new file)

@@ -0,0 +1,123 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.io.Serializable

/**
 * Artifact set for a job.
 * This class is used to store session (i.e. `SparkSession`) specific resources/artifacts.
 *
 * When Spark Connect is used, this job-set points towards session-specific jars and class files.
 * Note that Spark Connect is not a requirement for using this class.
 *
 * @param uuid An optional UUID for this session. If unset, a default session will be used.
 * @param replClassDirUri An optional custom URI to point towards class files.
 * @param jars Jars belonging to this session.
 * @param files Files belonging to this session.
 * @param archives Archives belonging to this session.
 */
class JobArtifactSet(
    val uuid: Option[String],
    val replClassDirUri: Option[String],
    val jars: Map[String, Long],
    val files: Map[String, Long],
    val archives: Map[String, Long]) extends Serializable {
  def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f)

  override def hashCode(): Int = {
    Seq(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq).hashCode()
  }

  override def equals(obj: Any): Boolean = {
    obj match {
      case that: JobArtifactSet =>
        this.getClass == that.getClass && this.uuid == that.uuid &&
          this.replClassDirUri == that.replClassDirUri && this.jars.toSeq == that.jars.toSeq &&
          this.files.toSeq == that.files.toSeq && this.archives.toSeq == that.archives.toSeq
    }
  }
}

object JobArtifactSet {

  private[this] val current = new ThreadLocal[Option[JobArtifactSet]] {
    override def initialValue(): Option[JobArtifactSet] = None
  }

  /**
   * When Spark Connect isn't used, we default back to the shared resources.
   * @param sc The active [[SparkContext]]
   * @return A [[JobArtifactSet]] containing a copy of the jars/files/archives from the underlying
   *         [[SparkContext]] `sc`.
   */
  def apply(sc: SparkContext): JobArtifactSet = {
    new JobArtifactSet(
      uuid = None,
      replClassDirUri = sc.conf.getOption("spark.repl.class.uri"),
      jars = sc.addedJars.toMap,
      files = sc.addedFiles.toMap,
      archives = sc.addedArchives.toMap)
  }

  /**
   * Empty artifact set for use in tests.
   */
  private[spark] def apply(): JobArtifactSet = {
Review comment (Contributor), on the empty `apply()` above: Just create a single empty one?
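A minimal sketch of what that suggestion could look like (illustrative only, not part of the PR as shown; the name `emptyJobArtifactSet` is made up):

// Hypothetical: reuse one immutable empty instance instead of allocating a new
// JobArtifactSet on every call, since the empty set never changes.
private[spark] lazy val emptyJobArtifactSet: JobArtifactSet =
  new JobArtifactSet(None, None, Map.empty, Map.empty, Map.empty)

private[spark] def apply(): JobArtifactSet = emptyJobArtifactSet

The PR as shown keeps the allocating version, which continues below.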
    new JobArtifactSet(
      None,
      None,
      Map.empty,
      Map.empty,
      Map.empty)
  }

  /**
   * Used for testing. Returns artifacts from [[SparkContext]] if one exists or otherwise, an
   * empty set.
   */
  private[spark] def defaultArtifactSet(): JobArtifactSet = {
    SparkContext.getActive.map(sc => JobArtifactSet(sc)).getOrElse(JobArtifactSet())
  }

  /**
   * Execute a block of code with the currently active [[JobArtifactSet]].
   * @param active The [[JobArtifactSet]] to activate for the duration of the block.
   * @param block The block of code to execute.
   * @tparam T The return type of the block.
   */
  def withActive[T](active: JobArtifactSet)(block: => T): T = {
    val old = current.get()
    current.set(Option(active))
    try block finally {
      current.set(old)
    }
  }

  /**
   * Optionally returns the active [[JobArtifactSet]].
   */
  def active: Option[JobArtifactSet] = current.get()

  /**
   * Returns the active [[JobArtifactSet]] or creates the default set using the [[SparkContext]].
   * @param sc The [[SparkContext]] used to build the default set when none is active.
   */
  def getActiveOrDefault(sc: SparkContext): JobArtifactSet = active.getOrElse(JobArtifactSet(sc))
}
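For orientation, a hedged usage sketch of the intended flow (the session UUID, URI, jar path, and the surrounding driver code are made-up assumptions, not part of this PR):

// Hypothetical driver-side usage: pin a session-specific artifact set around job
// submission so the scheduler can ship per-session jars/class files with the tasks.
val sessionArtifacts = new JobArtifactSet(
  uuid = Some("session-1234"),                                      // made-up session id
  replClassDirUri = Some("spark://host:1234/classes/session-1234"), // made-up URI
  jars = Map("/tmp/session-1234/udfs.jar" -> System.currentTimeMillis()),
  files = Map.empty,
  archives = Map.empty)

sessionArtifacts.withActive {
  // Inside this block, JobArtifactSet.active == Some(sessionArtifacts), so code that
  // submits jobs here can attach the session-specific artifacts to each task.
}

// Elsewhere, callers without an active set fall back to the SparkContext-wide resources:
// JobArtifactSet.getActiveOrDefault(sc)  // `sc` is an assumed active SparkContext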
Executor.scala
@@ -31,10 +31,11 @@ import javax.ws.rs.core.UriBuilder

 import scala.collection.JavaConverters._
+import scala.collection.immutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, Map, WrappedArray}
+import scala.collection.mutable.{ArrayBuffer, HashMap, WrappedArray}
 import scala.concurrent.duration._
 import scala.util.control.NonFatal

+import com.google.common.cache.CacheBuilder
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.slf4j.MDC
@@ -53,6 +54,14 @@ import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._

+private[spark] class IsolatedSessionState(
+    val sessionUUID: String,
+    val urlClassLoader: MutableURLClassLoader,
+    var replClassLoader: ClassLoader,
+    val currentFiles: HashMap[String, Long],
+    val currentJars: HashMap[String, Long],
+    val currentArchives: HashMap[String, Long])
+
 /**
  * Spark executor, backed by a threadpool to run tasks.
  *
@@ -76,11 +85,6 @@ private[spark] class Executor(
   val stopHookReference = ShutdownHookManager.addShutdownHook(
     () => stop()
   )
-  // Application dependencies (added through SparkContext) that we've fetched so far on this node.
-  // Each map holds the master's timestamp for the version of that file or JAR we got.
-  private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
-  private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
-  private val currentArchives: HashMap[String, Long] = new HashMap[String, Long]()

   private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
@@ -160,16 +164,34 @@

   private val killOnFatalErrorDepth = conf.get(EXECUTOR_KILL_ON_FATAL_ERROR_DEPTH)

   // Create our ClassLoader
   // do this after SparkEnv creation so can access the SecurityManager
-  private val urlClassLoader = createClassLoader()
-  private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+  private val systemLoader = Utils.getContextOrSparkClassLoader
+
+  private def newSessionState(
+      sessionUUID: String,
+      classUri: Option[String]): IsolatedSessionState = {
+    val currentFiles = new HashMap[String, Long]
+    val currentJars = new HashMap[String, Long]
+    val currentArchives = new HashMap[String, Long]
+    val urlClassLoader = createClassLoader(currentJars)
+    val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader, classUri)
+    new IsolatedSessionState(
+      sessionUUID, urlClassLoader, replClassLoader, currentFiles, currentJars, currentArchives)
+  }
+
+  // Classloader isolation
+  // The default isolation group
+  val defaultSessionState = newSessionState("default", None)
+
+  val isolatedSessionCache = CacheBuilder.newBuilder()
Review comment (Member), on the session cache above: There is another problem here. If you cache a Spark session, and say, it's evicted, then it will create a new session state with empty file lists. In this case, the …
+    .maximumSize(100)
+    .expireAfterAccess(5, TimeUnit.MINUTES)
+    .build[String, IsolatedSessionState]

   // Set the classloader for serializer
-  env.serializer.setDefaultClassLoader(replClassLoader)
+  env.serializer.setDefaultClassLoader(defaultSessionState.replClassLoader)
   // SPARK-21928. SerializerManager's internal instance of Kryo might get used in netty threads
   // for fetching remote cached RDD blocks, so need to make sure it uses the right classloader too.
-  env.serializerManager.setDefaultClassLoader(replClassLoader)
+  env.serializerManager.setDefaultClassLoader(defaultSessionState.replClassLoader)

   // Max size of direct result. If task result is bigger than this, we use the block manager
   // to send the result back. This is guaranteed to be smaller than array bytes limit (2GB)
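For context on the review thread above about eviction: with a Guava cache built this way, an entry that expires is rebuilt from scratch by the next `get(key, loader)` call, so any bookkeeping held only inside the cached value starts over. A standalone sketch under those assumptions (the `SessionState` type and the helper below are simplified stand-ins, not the PR's classes):

import java.util.concurrent.TimeUnit
import scala.collection.mutable
import com.google.common.cache.CacheBuilder

// Simplified stand-in for per-session bookkeeping (not the PR's IsolatedSessionState).
case class SessionState(sessionUUID: String, fetchedFiles: mutable.Map[String, Long])

val cache = CacheBuilder.newBuilder()
  .maximumSize(100)
  .expireAfterAccess(5, TimeUnit.MINUTES)
  .build[String, SessionState]()

def stateFor(uuid: String): SessionState =
  // If the entry was evicted, this loader runs again and returns a state with an
  // empty fetchedFiles map, i.e. previously recorded timestamps are forgotten.
  cache.get(uuid, () => SessionState(uuid, mutable.Map.empty))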
@@ -273,17 +295,18 @@
   private val Seq(initialUserJars, initialUserFiles, initialUserArchives) =
     Seq("jar", "file", "archive").map { key =>
       conf.getOption(s"spark.app.initial.$key.urls").map { urls =>
-        Map(urls.split(",").map(url => (url, appStartTime)): _*)
-      }.getOrElse(Map.empty)
+        immutable.Map(urls.split(",").map(url => (url, appStartTime)): _*)
+      }.getOrElse(immutable.Map.empty)
     }
-  updateDependencies(initialUserFiles, initialUserJars, initialUserArchives)
+  updateDependencies(initialUserFiles, initialUserJars, initialUserArchives, defaultSessionState)

   // Plugins need to load using a class loader that includes the executor's user classpath.
   // Plugins also needs to be initialized after the heartbeater started
   // to avoid blocking to send heartbeat (see SPARK-32175).
-  private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) {
-    PluginContainer(env, resources.asJava)
-  }
+  private val plugins: Option[PluginContainer] =
+    Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
+      PluginContainer(env, resources.asJava)
+    }

   metricsPoller.start()
@@ -381,9 +404,9 @@
     if (killMarkCleanupService != null) {
       killMarkCleanupService.shutdown()
     }
-    if (replClassLoader != null && plugins != null) {
+    if (defaultSessionState != null && plugins != null) {
       // Notify plugins that executor is shutting down so they can terminate cleanly
-      Utils.withContextClassLoader(replClassLoader) {
+      Utils.withContextClassLoader(defaultSessionState.replClassLoader) {
        plugins.foreach(_.shutdown())
      }
    }
@@ -485,6 +508,16 @@
     }

     override def run(): Unit = {
+
+      // Classloader isolation
+      val isolatedSessionUUID: Option[String] = taskDescription.artifacts.uuid
+      val isolatedSession = isolatedSessionUUID match {
+        case Some(uuid) => isolatedSessionCache.get(
+          uuid,
+          () => newSessionState(uuid, taskDescription.artifacts.replClassDirUri))
+        case _ => defaultSessionState
+      }
+
       setMDCForTask(taskName, mdcProperties)
       threadId = Thread.currentThread.getId
       Thread.currentThread.setName(threadName)
@@ -494,7 +527,7 @@
       val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
         threadMXBean.getCurrentThreadCpuTime
       } else 0L
-      Thread.currentThread.setContextClassLoader(replClassLoader)
+      Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
       val ser = env.closureSerializer.newInstance()
       logInfo(s"Running $taskName")
       execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@@ -509,7 +542,10 @@
         Executor.taskDeserializationProps.set(taskDescription.properties)

         updateDependencies(
-          taskDescription.addedFiles, taskDescription.addedJars, taskDescription.addedArchives)
+          taskDescription.artifacts.files,
+          taskDescription.artifacts.jars,
+          taskDescription.artifacts.archives,
+          isolatedSession)
         task = ser.deserialize[Task[Any]](
           taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
         task.localProperties = taskDescription.properties
@@ -961,15 +997,13 @@
    * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
    */
-  private def createClassLoader(): MutableURLClassLoader = {
+  private def createClassLoader(currentJars: HashMap[String, Long]): MutableURLClassLoader = {
     // Bootstrap the list of jars with the user class path.
     val now = System.currentTimeMillis()
     userClassPath.foreach { url =>
       currentJars(url.getPath().split("/").last) = now
     }

-    val currentLoader = Utils.getContextOrSparkClassLoader
-
     // For each of the jars in the jarSet, add them to the class loader.
     // We assume each of the files has already been fetched.
     val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
@@ -978,18 +1012,20 @@
     logInfo(s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " +
       urls.mkString("'", ",", "'"))
     if (userClassPathFirst) {
-      new ChildFirstURLClassLoader(urls, currentLoader)
+      new ChildFirstURLClassLoader(urls, systemLoader)
     } else {
-      new MutableURLClassLoader(urls, currentLoader)
+      new MutableURLClassLoader(urls, systemLoader)
     }
   }

   /**
    * If the REPL is in use, add another ClassLoader that will read
    * new classes defined by the REPL as the user types code
    */
-  private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
-    val classUri = conf.get("spark.repl.class.uri", null)
+  private def addReplClassLoaderIfNeeded(
+      parent: ClassLoader,
+      sessionClassUri: Option[String]): ClassLoader = {
+    val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null))
     if (classUri != null) {
       logInfo("Using REPL class URI: " + classUri)
       new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst)
@@ -1004,9 +1040,10 @@
    * Visible for testing.
    */
   private[executor] def updateDependencies(
-      newFiles: Map[String, Long],
-      newJars: Map[String, Long],
-      newArchives: Map[String, Long],
+      newFiles: immutable.Map[String, Long],
+      newJars: immutable.Map[String, Long],
+      newArchives: immutable.Map[String, Long],
+      state: IsolatedSessionState,
       testStartLatch: Option[CountDownLatch] = None,
       testEndLatch: Option[CountDownLatch] = None): Unit = {
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
@@ -1015,14 +1052,15 @@
       // For testing, so we can simulate a slow file download:
       testStartLatch.foreach(_.countDown())
       // Fetch missing dependencies
-      for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
+      for ((name, timestamp) <- newFiles if state.currentFiles.getOrElse(name, -1L) < timestamp) {
         logInfo(s"Fetching $name with timestamp $timestamp")
         // Fetch file with useCache mode, close cache for local mode.
         Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
           hadoopConf, timestamp, useCache = !isLocal)
-        currentFiles(name) = timestamp
+        state.currentFiles(name) = timestamp
       }
-      for ((name, timestamp) <- newArchives if currentArchives.getOrElse(name, -1L) < timestamp) {
+      for ((name, timestamp) <- newArchives if
+          state.currentArchives.getOrElse(name, -1L) < timestamp) {
         logInfo(s"Fetching $name with timestamp $timestamp")
         val sourceURI = new URI(name)
         val uriToDownload = UriBuilder.fromUri(sourceURI).fragment(null).build()
@@ -1035,24 +1073,24 @@
           s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
         Utils.deleteRecursively(dest)
         Utils.unpack(source, dest)
-        currentArchives(name) = timestamp
+        state.currentArchives(name) = timestamp
       }
       for ((name, timestamp) <- newJars) {
         val localName = new URI(name).getPath.split("/").last
-        val currentTimeStamp = currentJars.get(name)
-          .orElse(currentJars.get(localName))
+        val currentTimeStamp = state.currentJars.get(name)
+          .orElse(state.currentJars.get(localName))
           .getOrElse(-1L)
         if (currentTimeStamp < timestamp) {
           logInfo(s"Fetching $name with timestamp $timestamp")
           // Fetch file with useCache mode, close cache for local mode.
           Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf,
             hadoopConf, timestamp, useCache = !isLocal)
-          currentJars(name) = timestamp
+          state.currentJars(name) = timestamp
           // Add it to our class loader
           val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL
-          if (!urlClassLoader.getURLs().contains(url)) {
+          if (!state.urlClassLoader.getURLs().contains(url)) {
             logInfo(s"Adding $url to class loader")
-            urlClassLoader.addURL(url)
+            state.urlClassLoader.addURL(url)
           }
         }
       }
Review comment, on `JobArtifactSet.hashCode` above: `Objects.hash(...)` is a bit simpler.
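What that suggestion might look like, as a sketch only (the PR as shown keeps the `Seq(...).hashCode()` version):

import java.util.Objects

// Hypothetical alternative spelling of JobArtifactSet.hashCode, per the comment above.
override def hashCode(): Int =
  Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq)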