diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index 449ba011c219..1d784d813d10 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -29,7 +29,7 @@ import scala.reflect.ClassTag import org.apache.commons.io.FileUtils import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} -import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv} +import org.apache.spark.{JobArtifactState, SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils @@ -56,15 +56,15 @@ import org.apache.spark.util.Utils class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging { import SparkConnectArtifactManager._ - private val sessionUUID = sessionHolder.session.sessionUUID // The base directory/URI where all artifacts are stored for this `sessionUUID`. val (artifactPath, artifactURI): (Path, String) = getArtifactDirectoryAndUriForSession(sessionHolder) // The base directory/URI where all class file artifacts are stored for this `sessionUUID`. val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder) + val state: JobArtifactState = + JobArtifactState(sessionHolder.session.sessionUUID, Option(classURI)) private val jarsList = new CopyOnWriteArrayList[Path] - private val jarsURI = new CopyOnWriteArrayList[String] private val pythonIncludeList = new CopyOnWriteArrayList[String] /** @@ -132,10 +132,16 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging } Files.move(serverLocalStagingPath, target) if (remoteRelativePath.startsWith(s"jars${File.separator}")) { + sessionHolder.session.sessionState.resourceLoader + .addJar(target.toString, state.uuid) jarsList.add(target) - jarsURI.add(artifactURI + "/" + remoteRelativePath.toString) } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { - sessionHolder.session.sparkContext.addFile(target.toString) + sessionHolder.session.sparkContext.addFile( + target.toString, + recursive = false, + addedOnSubmit = false, + isArchive = false, + sessionUUID = state.uuid) val stringRemotePath = remoteRelativePath.toString if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith( ".egg") || stringRemotePath.endsWith(".jar")) { @@ -144,35 +150,28 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) { val canonicalUri = fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri) - sessionHolder.session.sparkContext.addArchive(canonicalUri.toString) + sessionHolder.session.sparkContext.addFile( + canonicalUri.toString, + recursive = false, + addedOnSubmit = false, + isArchive = true, + sessionUUID = state.uuid) } else if (remoteRelativePath.startsWith(s"files${File.separator}")) { - sessionHolder.session.sparkContext.addFile(target.toString) + sessionHolder.session.sparkContext.addFile( + target.toString, + recursive = false, + addedOnSubmit = false, + isArchive = false, + sessionUUID = state.uuid) } } } - /** - * Returns a [[JobArtifactSet]] pointing towards the 
session-specific jars and class files. - */ - def jobArtifactSet: JobArtifactSet = { - val builder = Map.newBuilder[String, Long] - jarsURI.forEach { jar => - builder += jar -> 0 - } - - new JobArtifactSet( - uuid = Option(sessionUUID), - replClassDirUri = Option(classURI), - jars = builder.result(), - files = Map.empty, - archives = Map.empty) - } - /** * Returns a [[ClassLoader]] for session-specific jar/class file resources. */ def classloader: ClassLoader = { - val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL + val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) } @@ -183,6 +182,12 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging logDebug( s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " + s"sessionId: ${sessionHolder.sessionId}") + + // Clean up added files + sessionHolder.session.sparkContext.addedFiles.remove(state.uuid) + sessionHolder.session.sparkContext.addedArchives.remove(state.uuid) + sessionHolder.session.sparkContext.addedJars.remove(state.uuid) + // Clean up cached relations val blockManager = sessionHolder.session.sparkContext.env.blockManager blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 56ef68abbc2f..332e960c25b7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -88,11 +88,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio artifactManager.addArtifact(remoteRelativePath, serverLocalStagingPath, fragment) } - /** - * A [[JobArtifactSet]] for this SparkConnect session. - */ - def connectJobArtifactSet: JobArtifactSet = artifactManager.jobArtifactSet - /** * A [[ClassLoader]] for jar/class file resources specific to this SparkConnect session. */ @@ -114,8 +109,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio def withContextClassLoader[T](f: => T): T = { // Needed for deserializing and evaluating the UDF on the driver Utils.withContextClassLoader(classloader) { - // Needed for propagating the dependencies to the executors. 
-      JobArtifactSet.withActive(connectJobArtifactSet) {
+      JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
         f
       }
     }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 345e458cd2f0..199290327cf8 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -51,20 +51,6 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     super.afterEach()
   }
 
-  test("Jar artifacts are added to spark session") {
-    val copyDir = Utils.createTempDir().toPath
-    FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
-    val stagingPath = copyDir.resolve("smallJar.jar")
-    val remotePath = Paths.get("jars/smallJar.jar")
-    artifactManager.addArtifact(remotePath, stagingPath, None)
-
-    val expectedPath = SparkConnectArtifactManager.artifactRootPath
-      .resolve(s"$sessionUUID/jars/smallJar.jar")
-    assert(expectedPath.toFile.exists())
-    val jars = artifactManager.jobArtifactSet.jars
-    assert(jars.exists(_._1.contains(remotePath.toString)))
-  }
-
   test("Class artifacts are added to the correct directory.") {
     val copyDir = Utils.createTempDir().toPath
     FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
index 3e402b3b3302..7b6c18277bc4 100644
--- a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
+++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala
@@ -20,6 +20,16 @@ package org.apache.spark
 import java.io.Serializable
 import java.util.Objects
 
+/**
+ * Job artifact state. For example, Spark Connect client sets the state specifically
+ * for the current client.
+ *
+ * @param uuid UUID to use in the current context of job artifact set. Usually this is from
+ *             a Spark Connect client.
+ * @param replClassDirUri The URI for the directory that stores REPL classes.
+ */
+private[spark] case class JobArtifactState(uuid: String, replClassDirUri: Option[String])
+
 /**
  * Artifact set for a job.
  * This class is used to store session (i.e `SparkSession`) specific resources/artifacts.
@@ -27,98 +37,83 @@ import java.util.Objects
  * When Spark Connect is used, this job-set points towards session-specific jars and class files.
  * Note that Spark Connect is not a requirement for using this class.
  *
- * @param uuid An optional UUID for this session. If unset, a default session will be used.
- * @param replClassDirUri An optional custom URI to point towards class files.
+ * @param state Job artifact state.
  * @param jars Jars belonging to this session.
  * @param files Files belonging to this session.
  * @param archives Archives belonging to this session.
*/ -class JobArtifactSet( - val uuid: Option[String], - val replClassDirUri: Option[String], +private[spark] class JobArtifactSet( + val state: Option[JobArtifactState], val jars: Map[String, Long], val files: Map[String, Long], val archives: Map[String, Long]) extends Serializable { - def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f) - override def hashCode(): Int = { - Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq) + Objects.hash(state, jars.toSeq, files.toSeq, archives.toSeq) } override def equals(obj: Any): Boolean = { obj match { case that: JobArtifactSet => - this.getClass == that.getClass && this.uuid == that.uuid && - this.replClassDirUri == that.replClassDirUri && this.jars.toSeq == that.jars.toSeq && + this.getClass == that.getClass && this.state == that.state && + this.jars.toSeq == that.jars.toSeq && this.files.toSeq == that.files.toSeq && this.archives.toSeq == that.archives.toSeq } } - } -object JobArtifactSet { - - private[this] val current = new ThreadLocal[Option[JobArtifactSet]] { - override def initialValue(): Option[JobArtifactSet] = None - } - - /** - * When Spark Connect isn't used, we default back to the shared resources. - * @param sc The active [[SparkContext]] - * @return A [[JobArtifactSet]] containing a copy of the jars/files/archives from the underlying - * [[SparkContext]] `sc`. - */ - def apply(sc: SparkContext): JobArtifactSet = { - new JobArtifactSet( - uuid = None, - replClassDirUri = sc.conf.getOption("spark.repl.class.uri"), - jars = sc.addedJars.toMap, - files = sc.addedFiles.toMap, - archives = sc.addedArchives.toMap) - } - private lazy val emptyJobArtifactSet = new JobArtifactSet( - None, - None, - Map.empty, - Map.empty, - Map.empty) +private[spark] object JobArtifactSet { + // For testing. + val emptyJobArtifactSet: JobArtifactSet = new JobArtifactSet( + None, Map.empty, Map.empty, Map.empty) + // For testing. + def defaultJobArtifactSet: JobArtifactSet = SparkContext.getActive.map( + getActiveOrDefault).getOrElse(emptyJobArtifactSet) - /** - * Empty artifact set for use in tests. - */ - private[spark] def apply(): JobArtifactSet = emptyJobArtifactSet + private[this] val currentClientSessionState: ThreadLocal[Option[JobArtifactState]] = + new ThreadLocal[Option[JobArtifactState]] { + override def initialValue(): Option[JobArtifactState] = None + } - /** - * Used for testing. Returns artifacts from [[SparkContext]] if one exists or otherwise, an - * empty set. - */ - private[spark] def defaultArtifactSet(): JobArtifactSet = { - SparkContext.getActive.map(sc => JobArtifactSet(sc)).getOrElse(JobArtifactSet()) - } + def getCurrentClientSessionState: Option[JobArtifactState] = currentClientSessionState.get() /** - * Execute a block of code with the currently active [[JobArtifactSet]]. - * @param active - * @param block - * @tparam T + * Set the Spark Connect specific information in the active client to the underlying + * [[JobArtifactSet]]. + * + * @param state Job artifact state. + * @return the result from the function applied with [[JobArtifactSet]] specific to + * the active client. 
*/ - def withActive[T](active: JobArtifactSet)(block: => T): T = { - val old = current.get() - current.set(Option(active)) + def withActiveJobArtifactState[T](state: JobArtifactState)(block: => T): T = { + val oldState = currentClientSessionState.get() + currentClientSessionState.set(Option(state)) try block finally { - current.set(old) + currentClientSessionState.set(oldState) } } /** - * Optionally returns the active [[JobArtifactSet]]. - */ - def active: Option[JobArtifactSet] = current.get() - - /** - * Return the active [[JobArtifactSet]] or creates the default set using the [[SparkContext]]. - * @param sc + * When Spark Connect isn't used, we default back to the shared resources. + * + * @param sc The active [[SparkContext]] + * @return A [[JobArtifactSet]] containing a copy of the jars/files/archives. + * If there is an active client, it sets the information from them. + * Otherwise, it falls back to the default in the [[SparkContext]]. */ - def getActiveOrDefault(sc: SparkContext): JobArtifactSet = active.getOrElse(JobArtifactSet(sc)) + def getActiveOrDefault(sc: SparkContext): JobArtifactSet = { + val maybeState = currentClientSessionState.get().map(s => s.copy( + replClassDirUri = s.replClassDirUri.orElse(sc.conf.getOption("spark.repl.class.uri")))) + new JobArtifactSet( + state = maybeState, + jars = maybeState + .map(s => sc.addedJars.getOrElse(s.uuid, sc.allAddedJars)) + .getOrElse(sc.allAddedJars).toMap, + files = maybeState + .map(s => sc.addedFiles.getOrElse(s.uuid, sc.allAddedFiles)) + .getOrElse(sc.allAddedFiles).toMap, + archives = maybeState + .map(s => sc.addedArchives.getOrElse(s.uuid, sc.allAddedArchives)) + .getOrElse(sc.allAddedArchives).toMap) + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 58f8310da707..6f99afe316cf 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,6 +26,7 @@ import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ import scala.collection.Map +import scala.collection.concurrent.{Map => ScalaConcurrentMap} import scala.collection.immutable import scala.collection.mutable.HashMap import scala.language.implicitConversions @@ -290,10 +291,18 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def env: SparkEnv = _env - // Used to store a URL for each static file/jar together with the file's local timestamp - private[spark] val addedFiles = new ConcurrentHashMap[String, Long]().asScala - private[spark] val addedArchives = new ConcurrentHashMap[String, Long]().asScala - private[spark] val addedJars = new ConcurrentHashMap[String, Long]().asScala + // Used to store session UUID with a URL for each static file/jar together and + // the file's local timestamp. It's session uuid -> (URL -> timestamp). 
+ private[spark] val addedFiles = new ConcurrentHashMap[ + String, ScalaConcurrentMap[String, Long]]().asScala + private[spark] val addedArchives = new ConcurrentHashMap[ + String, ScalaConcurrentMap[String, Long]]().asScala + private[spark] val addedJars = new ConcurrentHashMap[ + String, ScalaConcurrentMap[String, Long]]().asScala + + private[spark] def allAddedFiles = addedFiles.values.flatten.toMap + private[spark] def allAddedArchives = addedArchives.values.flatten.toMap + private[spark] def allAddedJars = addedJars.values.flatten.toMap // Keeps track of all persisted RDDs private[spark] val persistentRdds = { @@ -515,22 +524,22 @@ class SparkContext(config: SparkConf) extends Logging { // Add each JAR given through the constructor if (jars != null) { jars.foreach(jar => addJar(jar, true)) - if (addedJars.nonEmpty) { - _conf.set("spark.app.initial.jar.urls", addedJars.keys.toSeq.mkString(",")) + if (allAddedJars.nonEmpty) { + _conf.set("spark.app.initial.jar.urls", allAddedJars.keys.toSeq.mkString(",")) } } if (files != null) { files.foreach(file => addFile(file, false, true)) - if (addedFiles.nonEmpty) { - _conf.set("spark.app.initial.file.urls", addedFiles.keys.toSeq.mkString(",")) + if (allAddedFiles.nonEmpty) { + _conf.set("spark.app.initial.file.urls", allAddedFiles.keys.toSeq.mkString(",")) } } if (archives != null) { archives.foreach(file => addFile(file, false, true, isArchive = true)) - if (addedArchives.nonEmpty) { - _conf.set("spark.app.initial.archive.urls", addedArchives.keys.toSeq.mkString(",")) + if (allAddedArchives.nonEmpty) { + _conf.set("spark.app.initial.archive.urls", allAddedArchives.keys.toSeq.mkString(",")) } } @@ -1675,7 +1684,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns a list of file paths that are added to resources. */ - def listFiles(): Seq[String] = addedFiles.keySet.toSeq + def listFiles(): Seq[String] = allAddedFiles.keySet.toSeq /** * :: Experimental :: @@ -1705,7 +1714,7 @@ class SparkContext(config: SparkConf) extends Logging { * @since 3.1.0 */ @Experimental - def listArchives(): Seq[String] = addedArchives.keySet.toSeq + def listArchives(): Seq[String] = allAddedArchives.keySet.toSeq /** * Add a file to be downloaded with this Spark job on every node. @@ -1724,8 +1733,12 @@ class SparkContext(config: SparkConf) extends Logging { addFile(path, recursive, false) } - private def addFile( - path: String, recursive: Boolean, addedOnSubmit: Boolean, isArchive: Boolean = false + private[spark] def addFile( + path: String, + recursive: Boolean, + addedOnSubmit: Boolean, + isArchive: Boolean = false, + sessionUUID: String = "default" ): Unit = { val uri = Utils.resolveURI(path) val schemeCorrectedURI = uri.getScheme match { @@ -1762,7 +1775,11 @@ class SparkContext(config: SparkConf) extends Logging { } val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis - if (!isArchive && addedFiles.putIfAbsent(key, timestamp).isEmpty) { + if ( + !isArchive && + addedFiles + .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added file $path at $key with timestamp $timestamp") // Fetch the file locally so that closures which are run on the driver can still use the // SparkFiles API to access files. 
@@ -1771,7 +1788,9 @@ class SparkContext(config: SparkConf) extends Logging { postEnvironmentUpdate() } else if ( isArchive && - addedArchives.putIfAbsent( + addedArchives + .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .putIfAbsent( UriBuilder.fromUri(new URI(key)).fragment(uri.getFragment).build().toString, timestamp).isEmpty) { logInfo(s"Added archive $path at $key with timestamp $timestamp") @@ -2064,7 +2083,8 @@ class SparkContext(config: SparkConf) extends Logging { addJar(path, false) } - private def addJar(path: String, addedOnSubmit: Boolean): Unit = { + private[spark] def addJar( + path: String, addedOnSubmit: Boolean, sessionUUID: String = "default"): Unit = { def addLocalJarFile(file: File): Seq[String] = { try { if (!file.exists()) { @@ -2137,7 +2157,9 @@ class SparkContext(config: SparkConf) extends Logging { } if (keys.nonEmpty) { val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis - val (added, existed) = keys.partition(addedJars.putIfAbsent(_, timestamp).isEmpty) + val (added, existed) = keys.partition(addedJars + .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .putIfAbsent(_, timestamp).isEmpty) if (added.nonEmpty) { val jarMessage = if (scheme != "ivy") "JAR" else "dependency jars of Ivy URI" logInfo(s"Added $jarMessage $path at ${added.mkString(",")} with timestamp $timestamp") @@ -2155,7 +2177,7 @@ class SparkContext(config: SparkConf) extends Logging { /** * Returns a list of jar files that are added to resources. */ - def listJars(): Seq[String] = addedJars.keySet.toSeq + def listJars(): Seq[String] = allAddedJars.keySet.toSeq /** * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark @@ -2738,9 +2760,9 @@ class SparkContext(config: SparkConf) extends Logging { private def postEnvironmentUpdate(): Unit = { if (taskScheduler != null) { val schedulingMode = getSchedulingMode.toString - val addedJarPaths = addedJars.keys.toSeq - val addedFilePaths = addedFiles.keys.toSeq - val addedArchivePaths = addedArchives.keys.toSeq + val addedJarPaths = allAddedJars.keys.toSeq + val addedFilePaths = allAddedFiles.keys.toSeq + val addedArchivePaths = allAddedArchives.keys.toSeq val environmentDetails = SparkEnv.environmentDetails(conf, hadoopConfiguration, schedulingMode, addedJarPaths, addedFilePaths, addedArchivePaths, env.metricsSystem.metricsProperties.asScala.toMap) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 52cb15de83d3..9d7b941db0af 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -165,6 +165,9 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString) } + val sessionUUID = JobArtifactSet.getCurrentClientSessionState.map(_.uuid).getOrElse("default") + envVars.put("SPARK_CONNECT_SESSION_UUID", sessionUUID) + val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. 
When any codes try to release or
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 69a74146fad1..d6dcd906d92f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, DataOutputStream, EOFException, InputStream}
+import java.io.{DataInputStream, DataOutputStream, EOFException, File, InputStream}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 import java.util.Arrays
 import java.util.concurrent.TimeUnit
@@ -157,6 +157,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
     // Create and start the worker
     val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule))
+    val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default")
+    if (sessionId != "default") {
+      pb.directory(new File(SparkFiles.getRootDirectory(), sessionId))
+    }
     val workerEnv = pb.environment()
     workerEnv.putAll(envVars.asJava)
     workerEnv.put("PYTHONPATH", pythonPath)
@@ -210,6 +214,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     // Create and start the daemon
     val command = Arrays.asList(pythonExec, "-m", daemonModule)
     val pb = new ProcessBuilder(command)
+    val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default")
+    if (sessionId != "default") {
+      pb.directory(new File(SparkFiles.getRootDirectory(), sessionId))
+    }
     val workerEnv = pb.environment()
     workerEnv.putAll(envVars.asJava)
     workerEnv.put("PYTHONPATH", pythonPath)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 85fd66543cdd..b30569dc9641 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -169,25 +169,25 @@ private[spark] class Executor(
 
   private val systemLoader = Utils.getContextOrSparkClassLoader
 
-  private def newSessionState(
-      sessionUUID: String,
-      classUri: Option[String]): IsolatedSessionState = {
+  private def newSessionState(jobArtifactState: JobArtifactState): IsolatedSessionState = {
     val currentFiles = new HashMap[String, Long]
     val currentJars = new HashMap[String, Long]
     val currentArchives = new HashMap[String, Long]
     val urlClassLoader = createClassLoader(currentJars)
-    val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader, classUri)
+    val replClassLoader = addReplClassLoaderIfNeeded(
+      urlClassLoader, jobArtifactState.replClassDirUri)
     new IsolatedSessionState(
-      sessionUUID, urlClassLoader, replClassLoader, currentFiles, currentJars, currentArchives)
+      jobArtifactState.uuid, urlClassLoader, replClassLoader,
+      currentFiles, currentJars, currentArchives)
   }
 
   // Classloader isolation
   // The default isolation group
-  val defaultSessionState = newSessionState("default", None)
+  val defaultSessionState = newSessionState(JobArtifactState("default", None))
 
   val isolatedSessionCache = CacheBuilder.newBuilder()
     .maximumSize(100)
-    .expireAfterAccess(5, TimeUnit.MINUTES)
+    .expireAfterAccess(30, TimeUnit.MINUTES)
     .build[String, IsolatedSessionState]
 
   // Set the classloader for serializer
@@ -513,11 +513,10 @@ private[spark] class Executor(
 
     override def run(): Unit = {
       // Classloader isolation
-      val isolatedSessionUUID: Option[String] =
taskDescription.artifacts.uuid - val isolatedSession = isolatedSessionUUID match { - case Some(uuid) => isolatedSessionCache.get( - uuid, - () => newSessionState(uuid, taskDescription.artifacts.replClassDirUri)) + val isolatedSession = taskDescription.artifacts.state match { + case Some(jobArtifactState) => isolatedSessionCache.get( + jobArtifactState.uuid, + () => newSessionState(jobArtifactState)) case _ => defaultSessionState } @@ -1054,12 +1053,22 @@ private[spark] class Executor( try { // For testing, so we can simulate a slow file download: testStartLatch.foreach(_.countDown()) + + // If the session ID was specified from SparkSession, it's from a Spark Connect client. + // Specify a dedicated directory for Spark Connect client. + lazy val root = if (state.sessionUUID != "default") { + val newDest = new File(SparkFiles.getRootDirectory(), state.sessionUUID) + newDest.mkdir() + newDest + } else { + new File(SparkFiles.getRootDirectory()) + } + // Fetch missing dependencies for ((name, timestamp) <- newFiles if state.currentFiles.getOrElse(name, -1L) < timestamp) { logInfo(s"Fetching $name with timestamp $timestamp") // Fetch file with useCache mode, close cache for local mode. - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, - hadoopConf, timestamp, useCache = !isLocal) + Utils.fetchFile(name, root, conf, hadoopConf, timestamp, useCache = !isLocal) state.currentFiles(name) = timestamp } for ((name, timestamp) <- newArchives if @@ -1070,7 +1079,7 @@ private[spark] class Executor( val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf, hadoopConf, timestamp, useCache = !isLocal, shouldUntar = false) val dest = new File( - SparkFiles.getRootDirectory(), + root, if (sourceURI.getFragment != null) sourceURI.getFragment else source.getName) logInfo( s"Unpacking an archive $name from ${source.getAbsolutePath} to ${dest.getAbsolutePath}") @@ -1086,11 +1095,11 @@ private[spark] class Executor( if (currentTimeStamp < timestamp) { logInfo(s"Fetching $name with timestamp $timestamp") // Fetch file with useCache mode, close cache for local mode. 
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory()), conf, + Utils.fetchFile(name, root, conf, hadoopConf, timestamp, useCache = !isLocal) state.currentJars(name) = timestamp // Add it to our class loader - val url = new File(SparkFiles.getRootDirectory(), localName).toURI.toURL + val url = new File(root, localName).toURI.toURL if (!state.urlClassLoader.getURLs().contains(url)) { logInfo(s"Adding $url to class loader") state.urlClassLoader.addURL(url) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 0e30c165457d..753736735e40 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import scala.collection.immutable import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import org.apache.spark.JobArtifactSet +import org.apache.spark.{JobArtifactSet, JobArtifactState} import org.apache.spark.resource.ResourceInformation import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} @@ -133,8 +133,11 @@ private[spark] object TaskDescription { private def deserializeArtifacts(dataIn: DataInputStream): JobArtifactSet = { new JobArtifactSet( - uuid = deserializeOptionString(dataIn), - replClassDirUri = deserializeOptionString(dataIn), + state = deserializeOptionString(dataIn).map { uuid => + JobArtifactState( + uuid = uuid, + replClassDirUri = deserializeOptionString(dataIn)) + }, jars = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*), files = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*), archives = immutable.Map(deserializeStringLongMap(dataIn).toSeq: _*)) @@ -148,8 +151,10 @@ private[spark] object TaskDescription { } private def serializeArtifacts(artifacts: JobArtifactSet, dataOut: DataOutputStream): Unit = { - serializeOptionString(artifacts.uuid, dataOut) - serializeOptionString(artifacts.replClassDirUri, dataOut) + serializeOptionString(artifacts.state.map(_.uuid), dataOut) + artifacts.state.foreach { state => + serializeOptionString(state.replClassDirUri, dataOut) + } serializeStringLongMap(Map(artifacts.jars.toSeq: _*), dataOut) serializeStringLongMap(Map(artifacts.files.toSeq: _*), dataOut) serializeStringLongMap(Map(artifacts.archives.toSeq: _*), dataOut) diff --git a/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala index df09de1483ed..66d02e8b511a 100644 --- a/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobArtifactSetSuite.scala @@ -17,14 +17,33 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream, FileOutputStream} +import java.util.zip.{ZipEntry, ZipOutputStream} + +import org.apache.commons.io.IOUtils + class JobArtifactSetSuite extends SparkFunSuite with LocalSparkContext { + + private def createZipFile(inFile: String, outFile: String): Unit = { + val fileToZip = new File(inFile) + val fis = new FileInputStream(fileToZip) + val fos = new FileOutputStream(outFile) + val zipOut = new ZipOutputStream(fos) + val zipEntry = new ZipEntry(fileToZip.getName) + zipOut.putNextEntry(zipEntry) + IOUtils.copy(fis, zipOut) + IOUtils.closeQuietly(fis) + IOUtils.closeQuietly(zipOut) + } + test("JobArtifactSet uses resources from SparkContext") { withTempDir { dir => val jarPath = 
File.createTempFile("testJar", ".jar", dir).getAbsolutePath val filePath = File.createTempFile("testFile", ".txt", dir).getAbsolutePath - val archivePath = File.createTempFile("testZip", ".zip", dir).getAbsolutePath + val fileToZip = File.createTempFile("testFile", "", dir).getAbsolutePath + val archivePath = s"$fileToZip.zip" + createZipFile(fileToZip, archivePath) val conf = new SparkConf() .setAppName("test") @@ -34,54 +53,37 @@ class JobArtifactSetSuite extends SparkFunSuite with LocalSparkContext { sc.addJar(jarPath) sc.addFile(filePath) - sc.addJar(archivePath) + sc.addArchive(archivePath) val artifacts = JobArtifactSet.getActiveOrDefault(sc) - assert(artifacts.archives == sc.addedArchives) - assert(artifacts.files == sc.addedFiles) - assert(artifacts.jars == sc.addedJars) - assert(artifacts.replClassDirUri.contains("dummyUri")) + assert(artifacts.archives == sc.allAddedArchives) + assert(artifacts.files == sc.allAddedFiles) + assert(artifacts.jars == sc.allAddedJars) + assert(artifacts.state.isEmpty) } } test("The active JobArtifactSet is fetched if set") { withTempDir { dir => - val jarPath = File.createTempFile("testJar", ".jar", dir).getAbsolutePath - val filePath = File.createTempFile("testFile", ".txt", dir).getAbsolutePath - val archivePath = File.createTempFile("testZip", ".zip", dir).getAbsolutePath - val conf = new SparkConf() .setAppName("test") .setMaster("local") .set("spark.repl.class.uri", "dummyUri") sc = new SparkContext(conf) - sc.addJar(jarPath) - sc.addFile(filePath) - sc.addJar(archivePath) - - val artifactSet1 = new JobArtifactSet( - Some("123"), - Some("abc"), - Map("a" -> 1), - Map("b" -> 2), - Map("c" -> 3) - ) - - val artifactSet2 = new JobArtifactSet( - Some("789"), - Some("hjk"), - Map("x" -> 7), - Map("y" -> 8), - Map("z" -> 9) - ) + val artifactState1 = JobArtifactState("123", Some("abc")) + val artifactState2 = JobArtifactState("789", Some("hjk")) - JobArtifactSet.withActive(artifactSet1) { - JobArtifactSet.withActive(artifactSet2) { - assert(JobArtifactSet.getActiveOrDefault(sc) == artifactSet2) + JobArtifactSet.withActiveJobArtifactState(artifactState1) { + JobArtifactSet.withActiveJobArtifactState(artifactState2) { + assert(JobArtifactSet.getActiveOrDefault(sc).state.get == artifactState2) + assert(JobArtifactSet.getActiveOrDefault(sc).state.get.replClassDirUri.get == "hjk") } - assert(JobArtifactSet.getActiveOrDefault(sc) == artifactSet1) + assert(JobArtifactSet.getActiveOrDefault(sc).state.get == artifactState1) + assert(JobArtifactSet.getActiveOrDefault(sc).state.get.replClassDirUri.get == "abc") } + + assert(JobArtifactSet.getActiveOrDefault(sc).state.isEmpty) } } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 93d0d33101af..4145975741bc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -276,8 +276,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu sc.addJar(badURL) } assert(e2.getMessage.contains(badURL)) - assert(sc.addedFiles.isEmpty) - assert(sc.addedJars.isEmpty) + assert(sc.allAddedFiles.isEmpty) + assert(sc.allAddedJars.isEmpty) } } finally { sc.stop() diff --git a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala index 72ee0e96fd01..479429e3d3c5 100644 --- 
a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.executor import scala.util.Properties -import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.{JobArtifactSet, JobArtifactState, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { @@ -51,14 +51,14 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { // TestHelloV2's test method returns '2' val artifactSetWithHelloV2 = new JobArtifactSet( - uuid = Some("hello2"), - replClassDirUri = None, + Some(JobArtifactState(uuid = "hello2", replClassDirUri = None)), jars = Map(jar2 -> 1L), files = Map.empty, archives = Map.empty ) + sc.addJar(jar2, false, artifactSetWithHelloV2.state.get.uuid) - JobArtifactSet.withActive(artifactSetWithHelloV2) { + JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV2.state.get) { sc.parallelize(1 to 1).foreach { i => val cls = Utils.classForName("com.example.Hello$") val module = cls.getField("MODULE$").get(null) @@ -71,14 +71,14 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { // TestHelloV3's test method returns '3' val artifactSetWithHelloV3 = new JobArtifactSet( - uuid = Some("hello3"), - replClassDirUri = None, + Some(JobArtifactState(uuid = "hello3", replClassDirUri = None)), jars = Map(jar3 -> 1L), files = Map.empty, archives = Map.empty ) + sc.addJar(jar3, false, artifactSetWithHelloV3.state.get.uuid) - JobArtifactSet.withActive(artifactSetWithHelloV3) { + JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV3.state.get) { sc.parallelize(1 to 1).foreach { i => val cls = Utils.classForName("com.example.Hello$") val module = cls.getField("MODULE$").get(null) @@ -91,14 +91,14 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { // Should not be able to see any "Hello" class if they're excluded from the artifact set val artifactSetWithoutHello = new JobArtifactSet( - uuid = Some("Jar 1"), - replClassDirUri = None, + Some(JobArtifactState(uuid = "Jar 1", replClassDirUri = None)), jars = Map(jar1 -> 1L), files = Map.empty, archives = Map.empty ) + sc.addJar(jar1, false, artifactSetWithoutHello.state.get.uuid) - JobArtifactSet.withActive(artifactSetWithoutHello) { + JobArtifactSet.withActiveJobArtifactState(artifactSetWithoutHello.state.get) { sc.parallelize(1 to 1).foreach { i => try { Utils.classForName("com.example.Hello$") diff --git a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala index 9c61b1f8c27a..0dcc7c7f9b4c 100644 --- a/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/CoarseGrainedExecutorBackendSuite.scala @@ -306,7 +306,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite // We don't really verify the data, just pass it around. 
val data = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) val taskDescription = new TaskDescription(taskId, 2, "1", "TASK 1000000", 19, - 1, JobArtifactSet(), new Properties, 1, + 1, JobArtifactSet.emptyJobArtifactSet, new Properties, 1, Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data) val serializedTaskDescription = TaskDescription.encode(taskDescription) backend.rpcEnv.setupEndpoint("Executor 1", backend) @@ -422,7 +422,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite // Fake tasks with different taskIds. val taskDescriptions = (1 to numTasks).map { taskId => new TaskDescription(taskId, 2, "1", s"TASK $taskId", 19, - 1, JobArtifactSet(), new Properties, 1, + 1, JobArtifactSet.emptyJobArtifactSet, new Properties, 1, Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data) } assert(taskDescriptions.length == numTasks) @@ -511,7 +511,7 @@ class CoarseGrainedExecutorBackendSuite extends SparkFunSuite // Fake tasks with different taskIds. val taskDescriptions = (1 to numTasks).map { taskId => new TaskDescription(taskId, 2, "1", s"TASK $taskId", 19, - 1, JobArtifactSet(), new Properties, 1, + 1, JobArtifactSet.emptyJobArtifactSet, new Properties, 1, Map(GPU -> new ResourceInformation(GPU, Array("0", "1"))), data) } assert(taskDescriptions.length == numTasks) diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 72a6c7555c77..49e19dd2a00e 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -623,7 +623,7 @@ class ExecutorSuite extends SparkFunSuite numPartitions = 1, locs = Seq(), outputId = 0, - JobArtifactSet(), + JobArtifactSet.emptyJobArtifactSet, localProperties = new Properties(), serializedTaskMetrics = serializedTaskMetrics ) @@ -639,7 +639,7 @@ class ExecutorSuite extends SparkFunSuite name = "", index = 0, partitionId = 0, - JobArtifactSet(), + JobArtifactSet.emptyJobArtifactSet, properties = new Properties, cpus = 1, resources = immutable.Map[String, ResourceInformation](), diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index cb82c2e0a45e..bf5e9d96cd80 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -257,7 +257,8 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0"))) val taskCpus = 1 val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", - "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, taskResources, bytebuffer))) + "t1", 0, 1, JobArtifactSet.emptyJobArtifactSet, new Properties(), + taskCpus, taskResources, bytebuffer))) val ts = backend.getTaskSchedulerImpl() when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(taskDescs) @@ -363,7 +364,8 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo val taskResources = Map(GPU -> new ResourceInformation(GPU, Array("0"))) val taskCpus = 1 val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", - "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, taskResources, bytebuffer))) + "t1", 0, 1, 
JobArtifactSet.emptyJobArtifactSet, new Properties(), + taskCpus, taskResources, bytebuffer))) val ts = backend.getTaskSchedulerImpl() when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(taskDescs) @@ -455,7 +457,8 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo // Task cpus can be different from default resource profile when TaskResourceProfile is used. val taskCpus = 2 val taskDescs: Seq[Seq[TaskDescription]] = Seq(Seq(new TaskDescription(1, 0, "1", - "t1", 0, 1, JobArtifactSet(), new Properties(), taskCpus, Map.empty, bytebuffer))) + "t1", 0, 1, JobArtifactSet.emptyJobArtifactSet, new Properties(), + taskCpus, Map.empty, bytebuffer))) when(ts.resourceOffers(any[IndexedSeq[WorkerOffer]], any[Boolean])).thenReturn(taskDescs) backend.driverEndpoint.send(ReviveOffers) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 3aeb52cd37d0..c7e4994e328f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -491,7 +491,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, listener: JobListener = jobListener, - artifacts: JobArtifactSet = JobArtifactSet(sc), + artifacts: JobArtifactSet = JobArtifactSet.getActiveOrDefault(sc), properties: Properties = null): Int = { val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, func, partitions, CallSite("", ""), listener, artifacts, @@ -503,7 +503,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti private def submitMapStage( shuffleDep: ShuffleDependency[_, _, _], listener: JobListener = jobListener, - artifacts: JobArtifactSet = JobArtifactSet(sc)): Int = { + artifacts: JobArtifactSet = JobArtifactSet.getActiveOrDefault(sc)): Int = { val jobId = scheduler.nextJobId.getAndIncrement() runEvent(MapStageSubmitted(jobId, shuffleDep, CallSite("", ""), listener, artifacts)) jobId diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 2f65b608a46d..28049d9955e3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -30,7 +30,7 @@ class FakeTask( serializedTaskMetrics: Array[Byte] = SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array(), isBarrier: Boolean = false) - extends Task[Int](stageId, 0, partitionId, 1, JobArtifactSet.defaultArtifactSet(), + extends Task[Int](stageId, 0, partitionId, 1, JobArtifactSet.defaultJobArtifactSet, new Properties, serializedTaskMetrics, isBarrier = isBarrier) { override def runTask(context: TaskContext): Int = 0 @@ -96,7 +96,7 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { override def index: Int = i - }, 1, prefLocs(i), JobArtifactSet.defaultArtifactSet(), new Properties, + }, 1, prefLocs(i), JobArtifactSet.defaultJobArtifactSet, new Properties, SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) } new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null, diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index b1e1e9c50a26..0b6f27692057 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.{JobArtifactSet, TaskContext} * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0, 1, JobArtifactSet()) { + extends Task[Array[Byte]](stageId, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index f350e3cda51c..54a42c1a6618 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -70,7 +70,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, JobArtifactSet(sc), new Properties, + 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, + JobArtifactSet.getActiveOrDefault(sc), new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null, 1, null, Option.empty) @@ -92,7 +93,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, JobArtifactSet(sc), new Properties, + 0, 0, taskBinary, rdd.partitions(0), 1, Seq.empty, 0, + JobArtifactSet.getActiveOrDefault(sc), new Properties, closureSerializer.serialize(TaskMetrics.registered).array()) intercept[RuntimeException] { task.run(0, 0, null, 1, null, Option.empty) @@ -160,7 +162,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = intercept[TaskContextSuite.FakeTaskFailureException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = { throw new TaskContextSuite.FakeTaskFailureException @@ -192,7 +194,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = intercept[TaskContextSuite.FakeTaskFailureException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = { throw new TaskContextSuite.FakeTaskFailureException @@ -224,7 +226,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = 
intercept[TaskCompletionListenerException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = 0 }) @@ -255,7 +257,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = intercept[TaskCompletionListenerException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = 0 }) @@ -288,7 +290,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = intercept[TaskCompletionListenerException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = 0 }) @@ -321,7 +323,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }) val e = intercept[TaskContextSuite.FakeTaskFailureException] { - context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet(), + context.runTaskWithListeners(new Task[Int](0, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet, serializedTaskMetrics = Array.empty) { override def runTask(context: TaskContext): Int = { throw new TaskContextSuite.FakeTaskFailureException @@ -430,7 +432,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. val taskMetrics = TaskMetrics.empty - val task = new Task[Int](0, 0, 0, 1, JobArtifactSet(sc)) { + val task = new Task[Int](0, 0, 0, 1, JobArtifactSet.getActiveOrDefault(sc)) { context = new TaskContextImpl(0, 0, 0, 0L, 0, 1, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, @@ -453,7 +455,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. 
val taskMetrics = TaskMetrics.registered - val task = new Task[Int](0, 0, 0, 1, JobArtifactSet(sc)) { + val task = new Task[Int](0, 0, 0, 1, JobArtifactSet.getActiveOrDefault(sc)) { context = new TaskContextImpl(0, 0, 0, 0L, 0, 1, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 7f84806e1f87..b36363d0f4cd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -66,8 +66,7 @@ class TaskDescriptionSuite extends SparkFunSuite { val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) val artifacts = new JobArtifactSet( - uuid = None, - replClassDirUri = None, + None, jars = Map(originalJars.toSeq: _*), files = Map(originalFiles.toSeq: _*), archives = Map(originalArchives.toSeq: _*) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 2dd3b0fda203..c577348aef8e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -2155,11 +2155,13 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext new WorkerOffer("executor1", "host1", 1)) val task1 = new ShuffleMapTask(1, 0, null, new Partition { override def index: Int = 0 - }, 1, Seq(TaskLocation("host0", "executor0")), JobArtifactSet(sc), new Properties, null) + }, 1, Seq(TaskLocation("host0", "executor0")), + JobArtifactSet.getActiveOrDefault(sc), new Properties, null) val task2 = new ShuffleMapTask(1, 0, null, new Partition { override def index: Int = 1 - }, 1, Seq(TaskLocation("host1", "executor1")), JobArtifactSet(sc), new Properties, null) + }, 1, Seq(TaskLocation("host1", "executor1")), + JobArtifactSet.getActiveOrDefault(sc), new Properties, null) val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0, Some(0)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 10c1a72066fb..299ef5160599 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -184,7 +184,8 @@ class FakeTaskScheduler( /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, 1, JobArtifactSet()) { +class LargeTask(stageId: Int) extends Task[Array[Byte]]( + stageId, 0, 0, 1, JobArtifactSet.emptyJobArtifactSet) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KIB * 1024) val random = new Random(0) @@ -900,7 +901,8 @@ class TaskSetManagerSuite val singleTask = new ShuffleMapTask(0, 0, null, new Partition { override def index: Int = 0 - }, 1, Seq(TaskLocation("host1", "execA")), JobArtifactSet(sc), new Properties, null) + }, 1, Seq(TaskLocation("host1", "execA")), + JobArtifactSet.getActiveOrDefault(sc), new Properties, null) val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Some(0)) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) @@ -1519,7 +1521,7 @@ class TaskSetManagerSuite test("SPARK-21563 context's added jars shouldn't change mid-TaskSet") { sc = new SparkContext("local", "test") - val addedJarsPreTaskSet = Map[String, Long](sc.addedJars.toSeq: _*) + val addedJarsPreTaskSet = Map[String, Long](sc.allAddedJars.toSeq: _*) assert(addedJarsPreTaskSet.size === 0) sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -1535,7 +1537,7 @@ class TaskSetManagerSuite // even with a jar added mid-TaskSet val jarPath = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar") sc.addJar(jarPath.toString) - val addedJarsMidTaskSet = Map[String, Long](sc.addedJars.toSeq: _*) + val addedJarsMidTaskSet = Map[String, Long](sc.allAddedJars.toSeq: _*) assert(addedJarsPreTaskSet !== addedJarsMidTaskSet) val taskOption3 = manager1.resourceOffer("exec1", "host1", NO_PREF)._1 // which should have the old version of the jars list diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index ab285a2b862e..cbd00acf8290 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -20,6 +20,7 @@ import unittest import os +from pyspark.sql import SparkSession from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.utils import SPARK_HOME from pyspark import SparkFiles @@ -27,6 +28,7 @@ if should_test_connect: from pyspark.sql.connect.client.artifact import ArtifactManager + from pyspark.sql.connect.client import ChannelBuilder class ArtifactTests(ReusedConnectTestCase): @@ -228,7 +230,7 @@ def test_single_chunked_and_chunked_artifact(self): self.assertEqual(artifact2.data.crc, crc) self.assertEqual(artifact2.data.data, data) - def test_add_pyfile(self): + def check_add_pyfile(self, spark_session): with tempfile.TemporaryDirectory() as d: pyfile_path = os.path.join(d, "my_pyfile.py") with open(pyfile_path, "w") as f: @@ -240,10 +242,19 @@ def func(x): return my_pyfile.my_func() - self.spark.addArtifacts(pyfile_path, pyfile=True) - self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 10) + spark_session.addArtifacts(pyfile_path, pyfile=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 10) - def test_add_zipped_package(self): + def test_add_pyfile(self): + self.check_add_pyfile(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. 
+ self.check_add_pyfile( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_zipped_package(self, spark_session): with tempfile.TemporaryDirectory() as d: package_path = os.path.join(d, "my_zipfile") os.mkdir(package_path) @@ -258,10 +269,19 @@ def func(x): return my_zipfile.my_func() - self.spark.addArtifacts(f"{package_path}.zip", pyfile=True) - self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 5) + spark_session.addArtifacts(f"{package_path}.zip", pyfile=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 5) - def test_add_archive(self): + def test_add_zipped_package(self): + self.check_add_zipped_package(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. + self.check_add_zipped_package( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_archive(self, spark_session): with tempfile.TemporaryDirectory() as d: archive_path = os.path.join(d, "my_archive") os.mkdir(archive_path) @@ -280,10 +300,19 @@ def func(x): ) as my_file: return my_file.read().strip() - self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True) - self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!") + spark_session.addArtifacts(f"{archive_path}.zip#my_files", archive=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "hello world!") - def test_add_file(self): + def test_add_archive(self): + self.check_add_archive(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. + self.check_add_archive( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_file(self, spark_session): with tempfile.TemporaryDirectory() as d: file_path = os.path.join(d, "my_file.txt") with open(file_path, "w") as f: @@ -296,8 +325,17 @@ def func(x): ) as my_file: return my_file.read().strip() - self.spark.addArtifacts(file_path, file=True) - self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "Hello world!!") + spark_session.addArtifacts(file_path, file=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "Hello world!!") + + def test_add_file(self): + self.check_add_file(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. 
+ self.check_add_file( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) def test_copy_from_local_to_fs(self): with tempfile.TemporaryDirectory() as d: diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 114b667e6a4c..c44f22faa419 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -262,7 +262,7 @@ class MesosFineGrainedSchedulerBackendSuite name = "n1", index = 0, partitionId = 0, - artifacts = JobArtifactSet(), + artifacts = JobArtifactSet.emptyJobArtifactSet, properties = new Properties(), cpus = 1, resources = immutable.Map.empty[String, ResourceInformation], @@ -375,7 +375,7 @@ class MesosFineGrainedSchedulerBackendSuite name = "n1", index = 0, partitionId = 0, - artifacts = JobArtifactSet(), + artifacts = JobArtifactSet.emptyJobArtifactSet, properties = new Properties(), cpus = 1, resources = immutable.Map.empty[String, ResourceInformation], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 177a25b45fc3..2ab9d3c525cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -176,7 +176,9 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade * to add the jar to its hive client for the current session. Hence, it still needs to be in * [[SessionState]]. */ - def addJar(path: String): Unit = { + def addJar(path: String): Unit = addJar(path: String, sessionId = "default") + + private[spark] def addJar(path: String, sessionId: String): Unit = { val uri = Utils.resolveURI(path) resolveJars(uri).foreach { p => session.sparkContext.addJar(p) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 692c2215fdee..86d54ac967b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -3276,8 +3276,8 @@ class HiveDDLSuite val jarName = "TestUDTF.jar" val jar = spark.asInstanceOf[TestHiveSparkSession].getHiveFile(jarName).toURI.toString - spark.sparkContext.addedJars.keys.find(_.contains(jarName)) - .foreach(spark.sparkContext.addedJars.remove) + spark.sparkContext.allAddedJars.keys.find(_.contains(jarName)) + .foreach(spark.sparkContext.addedJars("default").remove) assert(!spark.sparkContext.listJars().exists(_.contains(jarName))) val e = intercept[AnalysisException] { sql("CREATE TEMPORARY FUNCTION f1 AS " +