@@ -18,18 +18,18 @@
package org.apache.spark.sql.connect.artifact

import java.io.File
import java.net.{URL, URLClassLoader}
import java.net.{URI, URL, URLClassLoader}
import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import javax.ws.rs.core.UriBuilder

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

import org.apache.commons.io.FileUtils
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

import org.apache.spark.{JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
@@ -92,7 +92,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
private[connect] def addArtifact(
remoteRelativePath: Path,
serverLocalStagingPath: Path,
fragment: Option[String]): Unit = {
fragment: Option[String]): Unit = JobArtifactSet.withActiveJobArtifactState(state) {
require(!remoteRelativePath.isAbsolute)
if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
@@ -131,38 +131,27 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
"Artifacts cannot be overwritten.")
}
Files.move(serverLocalStagingPath, target)

// This URI is for Spark file server that starts with "spark://".
val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath(
FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}"

if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
sessionHolder.session.sessionState.resourceLoader
.addJar(target.toString, state.uuid)
sessionHolder.session.sparkContext.addJar(uri)
jarsList.add(target)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(
target.toString,
recursive = false,
addedOnSubmit = false,
isArchive = false,
sessionUUID = state.uuid)
sessionHolder.session.sparkContext.addFile(uri)
val stringRemotePath = remoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
pythonIncludeList.add(target.getFileName.toString)
}
} else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
sessionHolder.session.sparkContext.addFile(
canonicalUri.toString,
recursive = false,
addedOnSubmit = false,
isArchive = true,
sessionUUID = state.uuid)
fragment.map(UriBuilder.fromUri(new URI(uri)).fragment).getOrElse(new URI(uri))
sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
} else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
sessionHolder.session.sparkContext.addFile(
target.toString,
recursive = false,
addedOnSubmit = false,
isArchive = false,
sessionUUID = state.uuid)
sessionHolder.session.sparkContext.addFile(uri)
}
}
}
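A minimal sketch (not the PR's exact code) of the flow SparkConnectArtifactManager.addArtifact now implements, assuming the artifactURI and state fields defined earlier in this class. JobArtifactSet is private[spark], so this illustrates the internal flow rather than a user-facing API:

    import java.nio.file.Path
    import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext}
    import org.apache.spark.util.Utils

    // Register one client-uploaded jar for a single Spark Connect session.
    def registerSessionJar(
        sc: SparkContext,
        state: JobArtifactState,   // carries the session UUID
        artifactURI: String,       // "spark://..." root already registered with the file server
        remoteRelativePath: Path): Unit = {
      JobArtifactSet.withActiveJobArtifactState(state) {
        // Turn e.g. "jars/foo.jar" into a raw URI path and append it to the served root,
        // mirroring the URI construction above.
        val uri = s"$artifactURI/" + Utils.encodeRelativeUnixPathToURIRawPath(
          remoteRelativePath.toString.replace('\\', '/'))
        // addJar sees the "spark" scheme, skips re-uploading the file, and records the jar
        // under the session UUID taken from the active JobArtifactState.
        sc.addJar(uri)
      }
    }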
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/JobArtifactSet.scala
@@ -69,13 +69,15 @@ private[spark] object JobArtifactSet {
// For testing.
def defaultJobArtifactSet: JobArtifactSet = SparkContext.getActive.map(
getActiveOrDefault).getOrElse(emptyJobArtifactSet)
// For testing
var lastSeenState: Option[JobArtifactState] = None

private[this] val currentClientSessionState: ThreadLocal[Option[JobArtifactState]] =
new ThreadLocal[Option[JobArtifactState]] {
override def initialValue(): Option[JobArtifactState] = None
}

def getCurrentClientSessionState: Option[JobArtifactState] = currentClientSessionState.get()
def getCurrentJobArtifactState: Option[JobArtifactState] = currentClientSessionState.get()

/**
* Set the Spark Connect specific information in the active client to the underlying
@@ -88,6 +90,7 @@ private[spark] object JobArtifactSet {
def withActiveJobArtifactState[T](state: JobArtifactState)(block: => T): T = {
val oldState = currentClientSessionState.get()
currentClientSessionState.set(Option(state))
lastSeenState = Option(state)
try block finally {
currentClientSessionState.set(oldState)
}
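The renamed getCurrentJobArtifactState together with withActiveJobArtifactState is the thread-local handshake the rest of the PR builds on. A small conceptual sketch of its semantics (both members are private[spark]; the two-field JobArtifactState constructor, a UUID plus an optional REPL class-dir URI, is assumed from its use elsewhere in this PR):

    val state = JobArtifactState("session-123", None)

    assert(JobArtifactSet.getCurrentJobArtifactState.isEmpty)
    JobArtifactSet.withActiveJobArtifactState(state) {
      // addJar/addFile calls (and PythonRDDs constructed) on this thread are now
      // attributed to "session-123" instead of the shared "default" state.
      assert(JobArtifactSet.getCurrentJobArtifactState.map(_.uuid).contains("session-123"))
    }
    // The previous state is restored afterwards (the finally block above guarantees this).
    assert(JobArtifactSet.getCurrentJobArtifactState.isEmpty)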
42 changes: 26 additions & 16 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1733,13 +1733,11 @@ class SparkContext(config: SparkConf) extends Logging {
addFile(path, recursive, false)
}

private[spark] def addFile(
path: String,
recursive: Boolean,
addedOnSubmit: Boolean,
isArchive: Boolean = false,
sessionUUID: String = "default"
private def addFile(
path: String, recursive: Boolean, addedOnSubmit: Boolean, isArchive: Boolean = false
): Unit = {
val jobArtifactUUID = JobArtifactSet
.getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
val uri = Utils.resolveURI(path)
val schemeCorrectedURI = uri.getScheme match {
case null => new File(path).getCanonicalFile.toURI
@@ -1752,7 +1750,7 @@

val hadoopPath = new Path(schemeCorrectedURI)
val scheme = schemeCorrectedURI.getScheme
if (!Array("http", "https", "ftp").contains(scheme) && !isArchive) {
if (!Array("http", "https", "ftp", "spark").contains(scheme) && !isArchive) {
val fs = hadoopPath.getFileSystem(hadoopConfiguration)
val isDir = fs.getFileStatus(hadoopPath).isDirectory
if (!isLocal && scheme == "file" && isDir) {
@@ -1775,21 +1773,31 @@
}

val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis
// If a job artifact UUID is set, the call comes from a Spark Connect client, so use a
// dedicated directory for that client. Spark Connect runs as a service, so the regular
// PySpark path is not affected.
lazy val root = if (jobArtifactUUID != "default") {
val newDest = new File(SparkFiles.getRootDirectory(), jobArtifactUUID)
Comment on lines +1776 to +1781

Contributor:
Is this needed because the session-specific handling is more generic now?
For the JARs from Spark Connect, we previously just registered the root artifact directory in the file server and built URIs that let the executor fetch the file directly, without needing to copy it over to the generic Spark Files directory.

Member (Author):
Yeah, it now needs to reuse PythonWorkerFactory, which assumes that there is a UUID-named directory under SparkFiles.getRootDirectory() on both the driver and the executors. We could try to reuse the local artifact directory, but I would prefer to keep a separate local copy for now, for better maintainability and reusability.

Otherwise it would upload to the Spark file server twice (as we discussed offline). I pushed new changes to avoid this, so after this change we no longer upload twice:

  1. The spark:// URI is passed directly to addFile and addJar.
  2. addFile and addJar do not attempt to upload the files again; they pass the original URI through.

Contributor:
> reuse PythonWorkerFactory, which assumes that there is a UUID-named directory under SparkFiles.getRootDirectory() on both the driver and the executors

Ahh gotcha, I am not very well aware of the Python side, good to know 👍

> after this change we no longer upload twice:
> 1. The spark:// URI is passed directly to addFile and addJar.
> 2. addFile and addJar do not attempt to upload the files again; they pass the original URI through.

Awesome!

newDest.mkdir()
newDest
} else {
new File(SparkFiles.getRootDirectory())
}
if (
!isArchive &&
addedFiles
.getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala)
.getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala)
.putIfAbsent(key, timestamp).isEmpty) {
logInfo(s"Added file $path at $key with timestamp $timestamp")
// Fetch the file locally so that closures which are run on the driver can still use the
// SparkFiles API to access files.
Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf,
hadoopConfiguration, timestamp, useCache = false)
Utils.fetchFile(uri.toString, root, conf, hadoopConfiguration, timestamp, useCache = false)
postEnvironmentUpdate()
} else if (
isArchive &&
addedArchives
.getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala)
.getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala)
.putIfAbsent(
UriBuilder.fromUri(new URI(key)).fragment(uri.getFragment).build().toString,
timestamp).isEmpty) {
@@ -1800,7 +1808,7 @@
val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf,
hadoopConfiguration, timestamp, useCache = false, shouldUntar = false)
val dest = new File(
SparkFiles.getRootDirectory(),
root,
if (uri.getFragment != null) uri.getFragment else source.getName)
logInfo(
s"Unpacking an archive $path from ${source.getAbsolutePath} to ${dest.getAbsolutePath}")
@@ -2083,8 +2091,9 @@ class SparkContext(config: SparkConf) extends Logging {
addJar(path, false)
}

private[spark] def addJar(
path: String, addedOnSubmit: Boolean, sessionUUID: String = "default"): Unit = {
private def addJar(path: String, addedOnSubmit: Boolean): Unit = {
val jobArtifactUUID = JobArtifactSet
.getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
def addLocalJarFile(file: File): Seq[String] = {
try {
if (!file.exists()) {
@@ -2094,6 +2103,7 @@
throw new IllegalArgumentException(
s"Directory ${file.getAbsoluteFile} is not allowed for addJar")
}

Seq(env.rpcEnv.fileServer.addJar(file))
} catch {
case NonFatal(e) =>
@@ -2105,7 +2115,7 @@
def checkRemoteJarFile(path: String): Seq[String] = {
val hadoopPath = new Path(path)
val scheme = hadoopPath.toUri.getScheme
if (!Array("http", "https", "ftp").contains(scheme)) {
if (!Array("http", "https", "ftp", "spark").contains(scheme)) {
try {
val fs = hadoopPath.getFileSystem(hadoopConfiguration)
if (!fs.exists(hadoopPath)) {
@@ -2158,7 +2168,7 @@
if (keys.nonEmpty) {
val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis
val (added, existed) = keys.partition(addedJars
.getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala)
.getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala)
.putIfAbsent(_, timestamp).isEmpty)
if (added.nonEmpty) {
val jarMessage = if (scheme != "ivy") "JAR" else "dependency jars of Ivy URI"
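As the review thread above explains, addFile/addJar and the Python side now all expect a UUID-named directory under SparkFiles.getRootDirectory() on the driver and on every executor. A condensed, illustrative sketch of the directory resolution the changes in this file implement (it mirrors the root value computed in addFile; JobArtifactSet is private[spark]):

    import java.io.File
    import org.apache.spark.{JobArtifactSet, SparkFiles}

    // "default" means a regular (non-Connect) application; any other UUID identifies a
    // Spark Connect session and gets its own sub-directory.
    def resolveFilesRoot(): File = {
      val jobArtifactUUID =
        JobArtifactSet.getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
      if (jobArtifactUUID != "default") {
        val dir = new File(SparkFiles.getRootDirectory(), jobArtifactUUID)
        dir.mkdir() // both the driver and the executors rely on this directory existing
        dir
      } else {
        new File(SparkFiles.getRootDirectory())
      }
    }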
@@ -52,6 +52,8 @@ private[spark] class PythonRDD(
isFromBarrier: Boolean = false)
extends RDD[Array[Byte]](parent) {

private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

override def getPartitions: Array[Partition] = firstParent.partitions

override val partitioner: Option[Partitioner] = {
@@ -61,7 +63,7 @@
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = PythonRunner(func)
val runner = PythonRunner(func, jobArtifactUUID)
runner.compute(firstParent.iterator(split, context), split.index, context)
}

20 changes: 12 additions & 8 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -93,7 +93,8 @@ private object BasePythonRunner {
private[spark] abstract class BasePythonRunner[IN, OUT](
protected val funcs: Seq[ChainedPythonFunctions],
protected val evalType: Int,
protected val argOffsets: Array[Array[Int]])
protected val argOffsets: Array[Array[Int]],
protected val jobArtifactUUID: Option[String])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
@@ -165,8 +166,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString)
}

val sessionUUID = JobArtifactSet.getCurrentClientSessionState.map(_.uuid).getOrElse("default")
envVars.put("SPARK_CONNECT_SESSION_UUID", sessionUUID)
envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
pythonExec, envVars.asScala.toMap)
@@ -381,7 +381,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}

// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
val root = jobArtifactUUID.map { uuid =>
new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
}.getOrElse(SparkFiles.getRootDirectory())
PythonRDD.writeUTF(root, dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
@@ -712,20 +715,21 @@ private[spark] object PythonRunner {

private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true)

def apply(func: PythonFunction): PythonRunner = {
def apply(func: PythonFunction, jobArtifactUUID: Option[String]): PythonRunner = {
if (printPythonInfo.compareAndSet(true, false)) {
PythonUtils.logPythonInfo(func.pythonExec)
}
new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))))
new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID)
}
}

/**
* A helper class to run Python mapPartition in Spark.
*/
private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions], jobArtifactUUID: Option[String])
extends BasePythonRunner[Array[Byte], Array[Byte]](
funcs, PythonEvalType.NON_UDF, Array(Array(0))) {
funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) {

protected override def newWriterThread(
env: SparkEnv,
@@ -157,9 +157,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String

// Create and start the worker
val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule))
val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default")
if (sessionId != "default") {
pb.directory(new File(SparkFiles.getRootDirectory(), sessionId))
val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default")
if (jobArtifactUUID != "default") {
val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID)
f.mkdir()
pb.directory(f)
}
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
@@ -214,9 +216,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Create and start the daemon
val command = Arrays.asList(pythonExec, "-m", daemonModule)
val pb = new ProcessBuilder(command)
val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default")
if (sessionId != "default") {
pb.directory(new File(SparkFiles.getRootDirectory(), sessionId))
val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default")
if (jobArtifactUUID != "default") {
val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID)
f.mkdir()
pb.directory(f)
}
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
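Together with the PythonRunner change above, worker startup now follows a small environment-variable protocol: the runner sets SPARK_JOB_ARTIFACT_UUID, and the factory uses it to choose the worker's working directory. A condensed sketch of that handshake (illustrative; the worker module name and the surrounding wiring are simplified):

    import java.io.File
    import java.util.Arrays
    import org.apache.spark.SparkFiles

    def buildWorkerProcess(pythonExec: String, envVars: Map[String, String]): ProcessBuilder = {
      val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", "pyspark.worker"))
      val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default")
      if (jobArtifactUUID != "default") {
        // Run the worker inside the session's directory so it only sees the Python files
        // and archives that addFile placed there for this session.
        val dir = new File(SparkFiles.getRootDirectory(), jobArtifactUUID)
        dir.mkdir()
        pb.directory(dir)
      }
      pb
    }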
13 changes: 12 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -442,7 +442,18 @@ private[spark] object Utils extends Logging with SparkClassUtils {
// `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as
// scheme or host. The prefix "/" is required because URI doesn't accept a relative path.
// We should remove it after we get the raw path.
new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1)
encodeRelativeUnixPathToURIRawPath(fileName).substring(1)
}

/**
* Same as [[encodeFileNameToURIRawPath]] but returns the relative UNIX path.
*/
def encodeRelativeUnixPathToURIRawPath(path: String): String = {
require(!path.startsWith("/") && !path.contains("\\"))
// `file` and `localhost` are not used. They only prevent URI from parsing `path` as a
// scheme or host. The prefix "/" is required because URI doesn't accept a relative path;
// callers that need a relative result strip it after getting the raw path.
new URI("file", null, "localhost", -1, "/" + path, null, null).getRawPath
}

/**
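For reference, a hypothetical call of the new helper (values illustrative): the multi-argument URI constructor percent-encodes characters that are not legal in a URI path while leaving the UNIX separators intact.

    // Relative UNIX path in, raw (percent-encoded) URI path out.
    Utils.encodeRelativeUnixPathToURIRawPath("jars/my lib.jar")
    // returns "/jars/my%20lib.jar"; this variant keeps the leading "/", which
    // encodeFileNameToURIRawPath then strips with substring(1).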
@@ -56,9 +56,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
files = Map.empty,
archives = Map.empty
)
sc.addJar(jar2, false, artifactSetWithHelloV2.state.get.uuid)

JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV2.state.get) {
sc.addJar(jar2)
sc.parallelize(1 to 1).foreach { i =>
val cls = Utils.classForName("com.example.Hello$")
val module = cls.getField("MODULE$").get(null)
@@ -76,9 +76,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
files = Map.empty,
archives = Map.empty
)
sc.addJar(jar3, false, artifactSetWithHelloV3.state.get.uuid)

JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV3.state.get) {
sc.addJar(jar3)
sc.parallelize(1 to 1).foreach { i =>
val cls = Utils.classForName("com.example.Hello$")
val module = cls.getField("MODULE$").get(null)
@@ -96,9 +96,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
files = Map.empty,
archives = Map.empty
)
sc.addJar(jar1, false, artifactSetWithoutHello.state.get.uuid)

JobArtifactSet.withActiveJobArtifactState(artifactSetWithoutHello.state.get) {
sc.addJar(jar1)
sc.parallelize(1 to 1).foreach { i =>
try {
Utils.classForName("com.example.Hello$")
10 changes: 10 additions & 0 deletions python/pyspark/sql/connect/session.py
@@ -754,6 +754,16 @@ def create_conf(**kwargs: Any) -> SparkConf:
pyutils = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr]
pyutils.addJarToCurrentClassLoader(connect_jar)

# Required for local-cluster testing, as its executors need the jars
# to load the Spark Connect plugin.
if master.startswith("local-cluster"):
if "spark.jars" in overwrite_conf:
overwrite_conf[
"spark.jars"
] = f"{overwrite_conf['spark.jars']},{connect_jar}"
else:
overwrite_conf["spark.jars"] = connect_jar

except ImportError:
pass
