@@ -29,7 +29,7 @@ import scala.reflect.ClassTag
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

-import org.apache.spark.{JobArtifactSet, SparkContext, SparkEnv}
+import org.apache.spark.{JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
@@ -56,15 +56,15 @@ import org.apache.spark.util.Utils
class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
import SparkConnectArtifactManager._

-private val sessionUUID = sessionHolder.session.sessionUUID
// The base directory/URI where all artifacts are stored for this `sessionUUID`.
val (artifactPath, artifactURI): (Path, String) =
getArtifactDirectoryAndUriForSession(sessionHolder)
// The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)
+val state: JobArtifactState =
+  JobArtifactState(sessionHolder.session.sessionUUID, Option(classURI))

private val jarsList = new CopyOnWriteArrayList[Path]
-private val jarsURI = new CopyOnWriteArrayList[String]
private val pythonIncludeList = new CopyOnWriteArrayList[String]

/**
@@ -132,10 +132,16 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
+sessionHolder.session.sessionState.resourceLoader
+  .addJar(target.toString, state.uuid)
jarsList.add(target)
-jarsURI.add(artifactURI + "/" + remoteRelativePath.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
-sessionHolder.session.sparkContext.addFile(target.toString)
+sessionHolder.session.sparkContext.addFile(
+  target.toString,
+  recursive = false,
+  addedOnSubmit = false,
+  isArchive = false,
+  sessionUUID = state.uuid)
val stringRemotePath = remoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
@@ -144,35 +150,28 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
} else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
-sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+sessionHolder.session.sparkContext.addFile(
+  canonicalUri.toString,
+  recursive = false,
+  addedOnSubmit = false,
+  isArchive = true,
+  sessionUUID = state.uuid)
} else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
-sessionHolder.session.sparkContext.addFile(target.toString)
+sessionHolder.session.sparkContext.addFile(
+  target.toString,
+  recursive = false,
+  addedOnSubmit = false,
+  isArchive = false,
+  sessionUUID = state.uuid)
}
}
}

-/**
- * Returns a [[JobArtifactSet]] pointing towards the session-specific jars and class files.
- */
-def jobArtifactSet: JobArtifactSet = {
-  val builder = Map.newBuilder[String, Long]
-  jarsURI.forEach { jar =>
-    builder += jar -> 0
-  }
-
-  new JobArtifactSet(
-    uuid = Option(sessionUUID),
-    replClassDirUri = Option(classURI),
-    jars = builder.result(),
-    files = Map.empty,
-    archives = Map.empty)
-}

/**
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
-val urls = jarsList.asScala.map(_.toUri.toURL) :+ classDir.toUri.toURL
+val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
}

@@ -183,6 +182,12 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
logDebug(
s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
s"sessionId: ${sessionHolder.sessionId}")

+// Clean up added files
+sessionHolder.session.sparkContext.addedFiles.remove(state.uuid)
+sessionHolder.session.sparkContext.addedArchives.remove(state.uuid)
+sessionHolder.session.sparkContext.addedJars.remove(state.uuid)

// Clean up cached relations
val blockManager = sessionHolder.session.sparkContext.env.blockManager
blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)
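The net effect in this file: instead of snapshotting jars into a per-session JobArtifactSet, every addFile/addJar call now carries `state.uuid`, and cleanUpResources drops that session's buckets from the SparkContext maps. A minimal self-contained Scala sketch of that bookkeeping model, with hypothetical names rather than the actual SparkContext internals:

import scala.collection.concurrent.TrieMap

// Toy model of session-keyed resource maps: the outer key is the session
// UUID ("default" stands in for the shared bucket used outside Spark
// Connect), the inner map is path -> registration timestamp.
object SessionScopedResources {
  private val addedFiles = TrieMap.empty[String, TrieMap[String, Long]]

  def addFile(path: String, sessionUUID: String = "default"): Unit =
    addedFiles
      .getOrElseUpdate(sessionUUID, TrieMap.empty[String, Long])
      .put(path, System.currentTimeMillis())

  // What a job for this session would see: its own bucket if present,
  // otherwise the shared one.
  def filesFor(sessionUUID: String): Map[String, Long] =
    addedFiles.get(sessionUUID).orElse(addedFiles.get("default"))
      .map(_.toMap).getOrElse(Map.empty)

  // Mirrors cleanUpResources(): removing the session's bucket releases
  // everything registered under that UUID in one operation.
  def cleanUp(sessionUUID: String): Unit = addedFiles.remove(sessionUUID)
}

Keying the outer map by session UUID is what lets a single remove() per map release every artifact a Connect session registered.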
@@ -88,11 +88,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
artifactManager.addArtifact(remoteRelativePath, serverLocalStagingPath, fragment)
}

-/**
- * A [[JobArtifactSet]] for this SparkConnect session.
- */
-def connectJobArtifactSet: JobArtifactSet = artifactManager.jobArtifactSet

/**
* A [[ClassLoader]] for jar/class file resources specific to this SparkConnect session.
*/
@@ -114,8 +109,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
def withContextClassLoader[T](f: => T): T = {
// Needed for deserializing and evaluating the UDF on the driver
Utils.withContextClassLoader(classloader) {
-// Needed for propagating the dependencies to the executors.
-JobArtifactSet.withActive(connectJobArtifactSet) {
+JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
f
}
}
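withContextClassLoader stacks two thread-scoped contexts for the duration of `f`: the session classloader, needed on the driver to deserialize client UDF code, and the client's JobArtifactState, so that jobs submitted inside `f` resolve that session's artifacts. A hedged sketch of the composition, using stand-ins for Utils.withContextClassLoader and JobArtifactSet.withActiveJobArtifactState; SessionState here is hypothetical:

object IsolatedExecution {
  final case class SessionState(uuid: String)

  private val activeState = new ThreadLocal[Option[SessionState]] {
    override def initialValue(): Option[SessionState] = None
  }

  // Swap the thread's context classloader for the duration of `body`,
  // restoring it even if `body` throws.
  def withContextClassLoader[T](cl: ClassLoader)(body: => T): T = {
    val thread = Thread.currentThread()
    val prev = thread.getContextClassLoader
    thread.setContextClassLoader(cl) // driver-side deserialization of client code
    try body finally thread.setContextClassLoader(prev)
  }

  // The same save/set/restore shape over a ThreadLocal holding the state.
  def withActiveState[T](state: SessionState)(body: => T): T = {
    val prev = activeState.get()
    activeState.set(Some(state)) // picked up when jobs are submitted
    try body finally activeState.set(prev)
  }

  // Composition as in SessionHolder: classloader outermost, state inside.
  def runIsolated[T](cl: ClassLoader, state: SessionState)(f: => T): T =
    withContextClassLoader(cl) { withActiveState(state) { f } }
}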
@@ -51,20 +51,6 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
super.afterEach()
}

test("Jar artifacts are added to spark session") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("smallJar.jar")
val remotePath = Paths.get("jars/smallJar.jar")
artifactManager.addArtifact(remotePath, stagingPath, None)

val expectedPath = SparkConnectArtifactManager.artifactRootPath
.resolve(s"$sessionUUID/jars/smallJar.jar")
assert(expectedPath.toFile.exists())
val jars = artifactManager.jobArtifactSet.jars
assert(jars.exists(_._1.contains(remotePath.toString)))
}

test("Class artifacts are added to the correct directory.") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
core/src/main/scala/org/apache/spark/JobArtifactSet.scala (123 changes: 59 additions, 64 deletions)
@@ -20,105 +20,100 @@ package org.apache.spark
import java.io.Serializable
import java.util.Objects

+/**
+ * Job artifact state. For example, a Spark Connect client sets the state specifically
+ * for the current client.
+ *
+ * @param uuid UUID to use in the current context of the job artifact set. Usually this
+ *             is from a Spark Connect client.
+ * @param replClassDirUri The URI for the directory that stores REPL classes.
+ */
+private[spark] case class JobArtifactState(uuid: String, replClassDirUri: Option[String])

/**
* Artifact set for a job.
* This class is used to store session (i.e. `SparkSession`) specific resources/artifacts.
*
* When Spark Connect is used, this job-set points towards session-specific jars and class files.
* Note that Spark Connect is not a requirement for using this class.
*
-* @param uuid An optional UUID for this session. If unset, a default session will be used.
-* @param replClassDirUri An optional custom URI to point towards class files.
+* @param state Job artifact state.
* @param jars Jars belonging to this session.
* @param files Files belonging to this session.
* @param archives Archives belonging to this session.
*/
-class JobArtifactSet(
-  val uuid: Option[String],
-  val replClassDirUri: Option[String],
+private[spark] class JobArtifactSet(
+  val state: Option[JobArtifactState],
val jars: Map[String, Long],
val files: Map[String, Long],
val archives: Map[String, Long]) extends Serializable {
-def withActive[T](f: => T): T = JobArtifactSet.withActive(this)(f)

override def hashCode(): Int = {
-Objects.hash(uuid, replClassDirUri, jars.toSeq, files.toSeq, archives.toSeq)
+Objects.hash(state, jars.toSeq, files.toSeq, archives.toSeq)
}

override def equals(obj: Any): Boolean = {
obj match {
case that: JobArtifactSet =>
-this.getClass == that.getClass && this.uuid == that.uuid &&
-  this.replClassDirUri == that.replClassDirUri && this.jars.toSeq == that.jars.toSeq &&
+this.getClass == that.getClass && this.state == that.state &&
+  this.jars.toSeq == that.jars.toSeq &&
this.files.toSeq == that.files.toSeq && this.archives.toSeq == that.archives.toSeq
}
}

}

-object JobArtifactSet {
-
-private[this] val current = new ThreadLocal[Option[JobArtifactSet]] {
-  override def initialValue(): Option[JobArtifactSet] = None
-}

-/**
- * When Spark Connect isn't used, we default back to the shared resources.
- * @param sc The active [[SparkContext]]
- * @return A [[JobArtifactSet]] containing a copy of the jars/files/archives from the underlying
- *         [[SparkContext]] `sc`.
- */
-def apply(sc: SparkContext): JobArtifactSet = {
-  new JobArtifactSet(
-    uuid = None,
-    replClassDirUri = sc.conf.getOption("spark.repl.class.uri"),
-    jars = sc.addedJars.toMap,
-    files = sc.addedFiles.toMap,
-    archives = sc.addedArchives.toMap)
-}
-
-private lazy val emptyJobArtifactSet = new JobArtifactSet(
-  None,
-  None,
-  Map.empty,
-  Map.empty,
-  Map.empty)
+private[spark] object JobArtifactSet {
+  // For testing.
+  val emptyJobArtifactSet: JobArtifactSet = new JobArtifactSet(
+    None, Map.empty, Map.empty, Map.empty)
+  // For testing.
+  def defaultJobArtifactSet: JobArtifactSet = SparkContext.getActive.map(
+    getActiveOrDefault).getOrElse(emptyJobArtifactSet)

-/**
- * Empty artifact set for use in tests.
- */
-private[spark] def apply(): JobArtifactSet = emptyJobArtifactSet
+private[this] val currentClientSessionState: ThreadLocal[Option[JobArtifactState]] =
+  new ThreadLocal[Option[JobArtifactState]] {
+    override def initialValue(): Option[JobArtifactState] = None
+  }

-/**
- * Used for testing. Returns artifacts from [[SparkContext]] if one exists or otherwise, an
- * empty set.
- */
-private[spark] def defaultArtifactSet(): JobArtifactSet = {
-  SparkContext.getActive.map(sc => JobArtifactSet(sc)).getOrElse(JobArtifactSet())
-}
+def getCurrentClientSessionState: Option[JobArtifactState] = currentClientSessionState.get()

/**
-* Execute a block of code with the currently active [[JobArtifactSet]].
-* @param active
-* @param block
-* @tparam T
+* Runs the block with the given Spark Connect client state set as the active state, so
+* that the underlying [[JobArtifactSet]] is built from that client's artifacts.
+*
+* @param state Job artifact state.
+* @return the result of the block, run with the [[JobArtifactSet]] specific to
+*         the active client.
*/
-def withActive[T](active: JobArtifactSet)(block: => T): T = {
-  val old = current.get()
-  current.set(Option(active))
+def withActiveJobArtifactState[T](state: JobArtifactState)(block: => T): T = {
+  val oldState = currentClientSessionState.get()
+  currentClientSessionState.set(Option(state))
try block finally {
-  current.set(old)
+  currentClientSessionState.set(oldState)
}
}
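Note the restore-not-clear discipline: the finally block resets the ThreadLocal to the saved previous value rather than removing it, which keeps nested scopes correct. A small self-contained demonstration; the withValue helper is hypothetical:

object NestedScopeDemo extends App {
  private val tl = new ThreadLocal[Option[String]] {
    override def initialValue(): Option[String] = None
  }

  def withValue[T](v: String)(body: => T): T = {
    val prev = tl.get()            // save the enclosing scope's value
    tl.set(Some(v))
    try body finally tl.set(prev)  // restore it, rather than clearing
  }

  withValue("outer") {
    withValue("inner") { assert(tl.get().contains("inner")) }
    // Restoring `prev` (not clearing) is what makes this hold:
    assert(tl.get().contains("outer"))
  }
  assert(tl.get().isEmpty)
}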

-/**
- * Optionally returns the active [[JobArtifactSet]].
- */
-def active: Option[JobArtifactSet] = current.get()

/**
-* Return the active [[JobArtifactSet]] or creates the default set using the [[SparkContext]].
-* @param sc
+* When Spark Connect isn't used, we default back to the shared resources.
+*
+* @param sc The active [[SparkContext]]
+* @return A [[JobArtifactSet]] containing a copy of the jars/files/archives. If there is
+*         an active client, the artifacts registered for that client are used; otherwise,
+*         it falls back to the defaults in the [[SparkContext]].
*/
-def getActiveOrDefault(sc: SparkContext): JobArtifactSet = active.getOrElse(JobArtifactSet(sc))
+def getActiveOrDefault(sc: SparkContext): JobArtifactSet = {
+  val maybeState = currentClientSessionState.get().map(s => s.copy(
+    replClassDirUri = s.replClassDirUri.orElse(sc.conf.getOption("spark.repl.class.uri"))))
+  new JobArtifactSet(
+    state = maybeState,
+    jars = maybeState
+      .map(s => sc.addedJars.getOrElse(s.uuid, sc.allAddedJars))
+      .getOrElse(sc.allAddedJars).toMap,
+    files = maybeState
+      .map(s => sc.addedFiles.getOrElse(s.uuid, sc.allAddedFiles))
+      .getOrElse(sc.allAddedFiles).toMap,
+    archives = maybeState
+      .map(s => sc.addedArchives.getOrElse(s.uuid, sc.allAddedArchives))
+      .getOrElse(sc.allAddedArchives).toMap)
+}
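The resolution rule above: with an active client state, take that session's bucket, falling back to the union of all buckets when the session has registered nothing yet; with no active state, take the union. A compact toy model of the jar lookup, with assumed data and an assumed allAddedJars analogue:

object ArtifactResolutionDemo extends App {
  // Session-keyed jars, as maintained by the session-aware addJar path.
  val addedJars: Map[String, Map[String, Long]] = Map(
    "session-a" -> Map("spark://host/jars/a.jar" -> 0L),
    "default" -> Map("spark://host/jars/shared.jar" -> 0L))

  // Analogue of sc.allAddedJars: the union across all sessions.
  def allAddedJars: Map[String, Long] = addedJars.values.flatten.toMap

  // Analogue of getActiveOrDefault's per-field logic.
  def resolveJars(activeUuid: Option[String]): Map[String, Long] =
    activeUuid
      .map(uuid => addedJars.getOrElse(uuid, allAddedJars)) // session bucket, else union
      .getOrElse(allAddedJars)                              // no client: shared view

  assert(resolveJars(Some("session-a")).keySet == Set("spark://host/jars/a.jar"))
  assert(resolveJars(None).size == 2)
  assert(resolveJars(Some("unknown")).size == 2) // unseen session falls back to union
}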
}
Loading