diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index 1d784d813d10..d8f290639c2f 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.artifact import java.io.File -import java.net.{URL, URLClassLoader} +import java.net.{URI, URL, URLClassLoader} import java.nio.file.{Files, Path, Paths, StandardCopyOption} import java.util.concurrent.CopyOnWriteArrayList import javax.ws.rs.core.UriBuilder @@ -26,10 +26,10 @@ import javax.ws.rs.core.UriBuilder import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.commons.io.FileUtils +import org.apache.commons.io.{FilenameUtils, FileUtils} import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} -import org.apache.spark.{JobArtifactState, SparkContext, SparkEnv} +import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils @@ -92,7 +92,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging private[connect] def addArtifact( remoteRelativePath: Path, serverLocalStagingPath: Path, - fragment: Option[String]): Unit = { + fragment: Option[String]): Unit = JobArtifactSet.withActiveJobArtifactState(state) { require(!remoteRelativePath.isAbsolute) if (remoteRelativePath.startsWith(s"cache${File.separator}")) { val tmpFile = serverLocalStagingPath.toFile @@ -131,17 +131,16 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging "Artifacts cannot be overwritten.") } Files.move(serverLocalStagingPath, target) + + // This URI is for Spark file server that starts with "spark://". 
+ val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath( + FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}" + if (remoteRelativePath.startsWith(s"jars${File.separator}")) { - sessionHolder.session.sessionState.resourceLoader - .addJar(target.toString, state.uuid) + sessionHolder.session.sparkContext.addJar(uri) jarsList.add(target) } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { - sessionHolder.session.sparkContext.addFile( - target.toString, - recursive = false, - addedOnSubmit = false, - isArchive = false, - sessionUUID = state.uuid) + sessionHolder.session.sparkContext.addFile(uri) val stringRemotePath = remoteRelativePath.toString if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith( ".egg") || stringRemotePath.endsWith(".jar")) { @@ -149,20 +148,10 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging } } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) { val canonicalUri = - fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri) - sessionHolder.session.sparkContext.addFile( - canonicalUri.toString, - recursive = false, - addedOnSubmit = false, - isArchive = true, - sessionUUID = state.uuid) + fragment.map(UriBuilder.fromUri(new URI(uri)).fragment).getOrElse(new URI(uri)) + sessionHolder.session.sparkContext.addArchive(canonicalUri.toString) } else if (remoteRelativePath.startsWith(s"files${File.separator}")) { - sessionHolder.session.sparkContext.addFile( - target.toString, - recursive = false, - addedOnSubmit = false, - isArchive = false, - sessionUUID = state.uuid) + sessionHolder.session.sparkContext.addFile(uri) } } } diff --git a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala index 7b6c18277bc4..54922f5783af 100644 --- a/core/src/main/scala/org/apache/spark/JobArtifactSet.scala +++ b/core/src/main/scala/org/apache/spark/JobArtifactSet.scala @@ -69,13 +69,15 @@ private[spark] object JobArtifactSet { // For testing. 
def defaultJobArtifactSet: JobArtifactSet = SparkContext.getActive.map( getActiveOrDefault).getOrElse(emptyJobArtifactSet) + // For testing + var lastSeenState: Option[JobArtifactState] = None private[this] val currentClientSessionState: ThreadLocal[Option[JobArtifactState]] = new ThreadLocal[Option[JobArtifactState]] { override def initialValue(): Option[JobArtifactState] = None } - def getCurrentClientSessionState: Option[JobArtifactState] = currentClientSessionState.get() + def getCurrentJobArtifactState: Option[JobArtifactState] = currentClientSessionState.get() /** * Set the Spark Connect specific information in the active client to the underlying @@ -88,6 +90,7 @@ private[spark] object JobArtifactSet { def withActiveJobArtifactState[T](state: JobArtifactState)(block: => T): T = { val oldState = currentClientSessionState.get() currentClientSessionState.set(Option(state)) + lastSeenState = Option(state) try block finally { currentClientSessionState.set(oldState) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6f99afe316cf..80f7eaf00ed2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1733,13 +1733,11 @@ class SparkContext(config: SparkConf) extends Logging { addFile(path, recursive, false) } - private[spark] def addFile( - path: String, - recursive: Boolean, - addedOnSubmit: Boolean, - isArchive: Boolean = false, - sessionUUID: String = "default" + private def addFile( + path: String, recursive: Boolean, addedOnSubmit: Boolean, isArchive: Boolean = false ): Unit = { + val jobArtifactUUID = JobArtifactSet + .getCurrentJobArtifactState.map(_.uuid).getOrElse("default") val uri = Utils.resolveURI(path) val schemeCorrectedURI = uri.getScheme match { case null => new File(path).getCanonicalFile.toURI @@ -1752,7 +1750,7 @@ class SparkContext(config: SparkConf) extends Logging { val hadoopPath = new Path(schemeCorrectedURI) val scheme = schemeCorrectedURI.getScheme - if (!Array("http", "https", "ftp").contains(scheme) && !isArchive) { + if (!Array("http", "https", "ftp", "spark").contains(scheme) && !isArchive) { val fs = hadoopPath.getFileSystem(hadoopConfiguration) val isDir = fs.getFileStatus(hadoopPath).isDirectory if (!isLocal && scheme == "file" && isDir) { @@ -1775,21 +1773,31 @@ class SparkContext(config: SparkConf) extends Logging { } val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis + // If the session ID was specified from SparkSession, it's from a Spark Connect client. + // Specify a dedicated directory for Spark Connect client. + // We're running Spark Connect as a service so regular PySpark path + // is not affected. + lazy val root = if (jobArtifactUUID != "default") { + val newDest = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + newDest.mkdir() + newDest + } else { + new File(SparkFiles.getRootDirectory()) + } if ( !isArchive && addedFiles - .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala) .putIfAbsent(key, timestamp).isEmpty) { logInfo(s"Added file $path at $key with timestamp $timestamp") // Fetch the file locally so that closures which are run on the driver can still use the // SparkFiles API to access files. 
- Utils.fetchFile(uri.toString, new File(SparkFiles.getRootDirectory()), conf, - hadoopConfiguration, timestamp, useCache = false) + Utils.fetchFile(uri.toString, root, conf, hadoopConfiguration, timestamp, useCache = false) postEnvironmentUpdate() } else if ( isArchive && addedArchives - .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala) .putIfAbsent( UriBuilder.fromUri(new URI(key)).fragment(uri.getFragment).build().toString, timestamp).isEmpty) { @@ -1800,7 +1808,7 @@ class SparkContext(config: SparkConf) extends Logging { val source = Utils.fetchFile(uriToDownload.toString, Utils.createTempDir(), conf, hadoopConfiguration, timestamp, useCache = false, shouldUntar = false) val dest = new File( - SparkFiles.getRootDirectory(), + root, if (uri.getFragment != null) uri.getFragment else source.getName) logInfo( s"Unpacking an archive $path from ${source.getAbsolutePath} to ${dest.getAbsolutePath}") @@ -2083,8 +2091,9 @@ class SparkContext(config: SparkConf) extends Logging { addJar(path, false) } - private[spark] def addJar( - path: String, addedOnSubmit: Boolean, sessionUUID: String = "default"): Unit = { + private def addJar(path: String, addedOnSubmit: Boolean): Unit = { + val jobArtifactUUID = JobArtifactSet + .getCurrentJobArtifactState.map(_.uuid).getOrElse("default") def addLocalJarFile(file: File): Seq[String] = { try { if (!file.exists()) { @@ -2094,6 +2103,7 @@ class SparkContext(config: SparkConf) extends Logging { throw new IllegalArgumentException( s"Directory ${file.getAbsoluteFile} is not allowed for addJar") } + Seq(env.rpcEnv.fileServer.addJar(file)) } catch { case NonFatal(e) => @@ -2105,7 +2115,7 @@ class SparkContext(config: SparkConf) extends Logging { def checkRemoteJarFile(path: String): Seq[String] = { val hadoopPath = new Path(path) val scheme = hadoopPath.toUri.getScheme - if (!Array("http", "https", "ftp").contains(scheme)) { + if (!Array("http", "https", "ftp", "spark").contains(scheme)) { try { val fs = hadoopPath.getFileSystem(hadoopConfiguration) if (!fs.exists(hadoopPath)) { @@ -2158,7 +2168,7 @@ class SparkContext(config: SparkConf) extends Logging { if (keys.nonEmpty) { val timestamp = if (addedOnSubmit) startTime else System.currentTimeMillis val (added, existed) = keys.partition(addedJars - .getOrElseUpdate(sessionUUID, new ConcurrentHashMap[String, Long]().asScala) + .getOrElseUpdate(jobArtifactUUID, new ConcurrentHashMap[String, Long]().asScala) .putIfAbsent(_, timestamp).isEmpty) if (added.nonEmpty) { val jarMessage = if (scheme != "ivy") "JAR" else "dependency jars of Ivy URI" diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 071ea50e9bd4..95fbc145d835 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -52,6 +52,8 @@ private[spark] class PythonRDD( isFromBarrier: Boolean = false) extends RDD[Array[Byte]](parent) { + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + override def getPartitions: Array[Partition] = firstParent.partitions override val partitioner: Option[Partitioner] = { @@ -61,7 +63,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = 
PythonRunner(func) + val runner = PythonRunner(func, jobArtifactUUID) runner.compute(firstParent.iterator(split, context), split.index, context) } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 9d7b941db0af..902252084997 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -93,7 +93,8 @@ private object BasePythonRunner { private[spark] abstract class BasePythonRunner[IN, OUT]( protected val funcs: Seq[ChainedPythonFunctions], protected val evalType: Int, - protected val argOffsets: Array[Array[Int]]) + protected val argOffsets: Array[Array[Int]], + protected val jobArtifactUUID: Option[String]) extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") @@ -165,8 +166,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( envVars.put("PYTHON_FAULTHANDLER_DIR", BasePythonRunner.faultHandlerLogDir.toString) } - val sessionUUID = JobArtifactSet.getCurrentClientSessionState.map(_.uuid).getOrElse("default") - envVars.put("SPARK_CONNECT_SESSION_UUID", sessionUUID) + envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( pythonExec, envVars.asScala.toMap) @@ -381,7 +381,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( } // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + val root = jobArtifactUUID.map { uuid => + new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath + }.getOrElse(SparkFiles.getRootDirectory()) + PythonRDD.writeUTF(root, dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size) for (include <- pythonIncludes) { @@ -712,20 +715,21 @@ private[spark] object PythonRunner { private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true) - def apply(func: PythonFunction): PythonRunner = { + def apply(func: PythonFunction, jobArtifactUUID: Option[String]): PythonRunner = { if (printPythonInfo.compareAndSet(true, false)) { PythonUtils.logPythonInfo(func.pythonExec) } - new PythonRunner(Seq(ChainedPythonFunctions(Seq(func)))) + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID) } } /** * A helper class to run Python mapPartition in Spark. 
*/ -private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions]) +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], jobArtifactUUID: Option[String]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, PythonEvalType.NON_UDF, Array(Array(0))) { + funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) { protected override def newWriterThread( env: SparkEnv, diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 1f5c079f9994..19181bd98e11 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -157,9 +157,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Create and start the worker val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) - val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default") - if (sessionId != "default") { - pb.directory(new File(SparkFiles.getRootDirectory(), sessionId)) + val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + if (jobArtifactUUID != "default") { + val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + f.mkdir() + pb.directory(f) } val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) @@ -214,9 +216,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Create and start the daemon val command = Arrays.asList(pythonExec, "-m", daemonModule) val pb = new ProcessBuilder(command) - val sessionId = envVars.getOrElse("SPARK_CONNECT_SESSION_UUID", "default") - if (sessionId != "default") { - pb.directory(new File(SparkFiles.getRootDirectory(), sessionId)) + val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + if (jobArtifactUUID != "default") { + val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + f.mkdir() + pb.directory(f) } val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d4b2651748e8..83f677954aee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -442,7 +442,18 @@ private[spark] object Utils extends Logging with SparkClassUtils { // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. // We should remove it after we get the raw path. - new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + encodeRelativeUnixPathToURIRawPath(fileName).substring(1) + } + + /** + * Same as [[encodeFileNameToURIRawPath]] but returns the relative UNIX path. + */ + def encodeRelativeUnixPathToURIRawPath(path: String): String = { + require(!path.startsWith("/") && !path.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. + // We should remove it after we get the raw path. 
+ new URI("file", null, "localhost", -1, "/" + path, null, null).getRawPath } /** diff --git a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala index 479429e3d3c5..637b459886bc 100644 --- a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala @@ -56,9 +56,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { files = Map.empty, archives = Map.empty ) - sc.addJar(jar2, false, artifactSetWithHelloV2.state.get.uuid) JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV2.state.get) { + sc.addJar(jar2) sc.parallelize(1 to 1).foreach { i => val cls = Utils.classForName("com.example.Hello$") val module = cls.getField("MODULE$").get(null) @@ -76,9 +76,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { files = Map.empty, archives = Map.empty ) - sc.addJar(jar3, false, artifactSetWithHelloV3.state.get.uuid) JobArtifactSet.withActiveJobArtifactState(artifactSetWithHelloV3.state.get) { + sc.addJar(jar3) sc.parallelize(1 to 1).foreach { i => val cls = Utils.classForName("com.example.Hello$") val module = cls.getField("MODULE$").get(null) @@ -96,9 +96,9 @@ class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext { files = Map.empty, archives = Map.empty ) - sc.addJar(jar1, false, artifactSetWithoutHello.state.get.uuid) JobArtifactSet.withActiveJobArtifactState(artifactSetWithoutHello.state.get) { + sc.addJar(jar1) sc.parallelize(1 to 1).foreach { i => try { Utils.classForName("com.example.Hello$") diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 358674c9189f..ea88d60d7608 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -754,6 +754,16 @@ def create_conf(**kwargs: Any) -> SparkConf: pyutils = SparkContext._jvm.PythonSQLUtils # type: ignore[union-attr] pyutils.addJarToCurrentClassLoader(connect_jar) + # Required for local-cluster testing as their executors need the jars + # to load the Spark plugin for Spark Connect. + if master.startswith("local-cluster"): + if "spark.jars" in overwrite_conf: + overwrite_conf[ + "spark.jars" + ] = f"{overwrite_conf['spark.jars']},{connect_jar}" + else: + overwrite_conf["spark.jars"] = connect_jar + except ImportError: pass diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index c685000b5ea4..4aec179bcf2a 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -31,7 +31,141 @@ from pyspark.sql.connect.client import ChannelBuilder -class ArtifactTests(ReusedConnectTestCase): +class ArtifactTestsMixin: + def check_add_pyfile(self, spark_session): + with tempfile.TemporaryDirectory() as d: + pyfile_path = os.path.join(d, "my_pyfile.py") + with open(pyfile_path, "w") as f: + f.write("my_func = lambda: 10") + + @udf("int") + def func(x): + import my_pyfile + + return my_pyfile.my_func() + + spark_session.addArtifacts(pyfile_path, pyfile=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 10) + + def test_add_pyfile(self): + self.check_add_pyfile(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. 
+ self.check_add_pyfile( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_zipped_package(self, spark_session): + with tempfile.TemporaryDirectory() as d: + package_path = os.path.join(d, "my_zipfile") + os.mkdir(package_path) + pyfile_path = os.path.join(package_path, "__init__.py") + with open(pyfile_path, "w") as f: + _ = f.write("my_func = lambda: 5") + shutil.make_archive(package_path, "zip", d, "my_zipfile") + + @udf("long") + def func(x): + import my_zipfile + + return my_zipfile.my_func() + + spark_session.addArtifacts(f"{package_path}.zip", pyfile=True) + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 5) + + def test_add_zipped_package(self): + self.check_add_zipped_package(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. + self.check_add_zipped_package( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_archive(self, spark_session): + with tempfile.TemporaryDirectory() as d: + archive_path = os.path.join(d, "my_archive") + os.mkdir(archive_path) + pyfile_path = os.path.join(archive_path, "my_file.txt") + with open(pyfile_path, "w") as f: + _ = f.write("hello world!") + shutil.make_archive(archive_path, "zip", d, "my_archive") + + # Should addArtifact first to make sure state is set, + # and 'root' can be found properly. + spark_session.addArtifacts(f"{archive_path}.zip#my_files", archive=True) + + root = self.root() + + @udf("string") + def func(x): + with open( + os.path.join(root, "my_files", "my_archive", "my_file.txt"), + "r", + ) as my_file: + return my_file.read().strip() + + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "hello world!") + + def test_add_archive(self): + self.check_add_archive(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. + self.check_add_archive( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + def check_add_file(self, spark_session): + with tempfile.TemporaryDirectory() as d: + file_path = os.path.join(d, "my_file.txt") + with open(file_path, "w") as f: + f.write("Hello world!!") + + # Should addArtifact first to make sure state is set, + # and 'root' can be found properly. + spark_session.addArtifacts(file_path, file=True) + + root = self.root() + + @udf("string") + def func(x): + with open(os.path.join(root, "my_file.txt"), "r") as my_file: + return my_file.read().strip() + + self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "Hello world!!") + + def test_add_file(self): + self.check_add_file(self.spark) + + # Test multi sessions. Should be able to add the same + # file from different session. + self.check_add_file( + SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() + ) + + +class ArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): + @classmethod + def root(cls): + # In local mode, the file location is the same as Driver + # The executors are running in a thread. 
+ jvm = SparkSession._instantiatedSession._jvm + current_uuid = ( + getattr( + getattr( + jvm.org.apache.spark, # type: ignore[union-attr] + "JobArtifactSet$", + ), + "MODULE$", + ) + .lastSeenState() + .get() + .uuid() + ) + return os.path.join(SparkFiles.getRootDirectory(), current_uuid) + @classmethod def setUpClass(cls): super(ArtifactTests, cls).setUpClass() @@ -230,117 +364,6 @@ def test_single_chunked_and_chunked_artifact(self): self.assertEqual(artifact2.data.crc, crc) self.assertEqual(artifact2.data.data, data) - def check_add_pyfile(self, spark_session): - with tempfile.TemporaryDirectory() as d: - pyfile_path = os.path.join(d, "my_pyfile.py") - with open(pyfile_path, "w") as f: - f.write("my_func = lambda: 10") - - @udf("long") - def func(x): - import my_pyfile - - return my_pyfile.my_func() - - spark_session.addArtifacts(pyfile_path, pyfile=True) - self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 10) - - @unittest.skip("SPARK-44348: Reenable Session-based artifact test cases") - def test_add_pyfile(self): - self.check_add_pyfile(self.spark) - - # Test multi sessions. Should be able to add the same - # file from different session. - self.check_add_pyfile( - SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() - ) - - def check_add_zipped_package(self, spark_session): - with tempfile.TemporaryDirectory() as d: - package_path = os.path.join(d, "my_zipfile") - os.mkdir(package_path) - pyfile_path = os.path.join(package_path, "__init__.py") - with open(pyfile_path, "w") as f: - _ = f.write("my_func = lambda: 5") - shutil.make_archive(package_path, "zip", d, "my_zipfile") - - @udf("long") - def func(x): - import my_zipfile - - return my_zipfile.my_func() - - spark_session.addArtifacts(f"{package_path}.zip", pyfile=True) - self.assertEqual(spark_session.range(1).select(func("id")).first()[0], 5) - - @unittest.skip("SPARK-44348: Reenable Session-based artifact test cases") - def test_add_zipped_package(self): - self.check_add_zipped_package(self.spark) - - # Test multi sessions. Should be able to add the same - # file from different session. - self.check_add_zipped_package( - SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() - ) - - def check_add_archive(self, spark_session): - with tempfile.TemporaryDirectory() as d: - archive_path = os.path.join(d, "my_archive") - os.mkdir(archive_path) - pyfile_path = os.path.join(archive_path, "my_file.txt") - with open(pyfile_path, "w") as f: - _ = f.write("hello world!") - shutil.make_archive(archive_path, "zip", d, "my_archive") - - @udf("string") - def func(x): - with open( - os.path.join( - SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt" - ), - "r", - ) as my_file: - return my_file.read().strip() - - spark_session.addArtifacts(f"{archive_path}.zip#my_files", archive=True) - self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "hello world!") - - @unittest.skip("SPARK-44348: Reenable Session-based artifact test cases") - def test_add_archive(self): - self.check_add_archive(self.spark) - - # Test multi sessions. Should be able to add the same - # file from different session. 
- self.check_add_archive( - SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() - ) - - def check_add_file(self, spark_session): - with tempfile.TemporaryDirectory() as d: - file_path = os.path.join(d, "my_file.txt") - with open(file_path, "w") as f: - f.write("Hello world!!") - - @udf("string") - def func(x): - with open( - os.path.join(SparkFiles.getRootDirectory(), "my_file.txt"), "r" - ) as my_file: - return my_file.read().strip() - - spark_session.addArtifacts(file_path, file=True) - self.assertEqual(spark_session.range(1).select(func("id")).first()[0], "Hello world!!") - - @unittest.skip("SPARK-44348: Reenable Session-based artifact test cases") - def test_add_file(self): - self.check_add_file(self.spark) - - # Test multi sessions. Should be able to add the same - # file from different session. - self.check_add_file( - SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create() - ) - def test_copy_from_local_to_fs(self): with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d2: @@ -366,6 +389,17 @@ def test_cache_artifact(self): self.assertEqual(self.artifact_manager.is_cached_artifact(expected_hash), True) +class LocalClusterArtifactTests(ReusedConnectTestCase, ArtifactTestsMixin): + @classmethod + def root(cls): + # In local cluster, we can mimic the production usage. + return "." + + @classmethod + def master(cls): + return "local-cluster[2,2,1024]" + + if __name__ == "__main__": from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401 diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index 68e08d5244f7..1b3ac10fce87 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -154,12 +154,16 @@ def conf(cls): conf.set("spark.sql.pyspark.jvmStacktrace.enabled", "false") return conf + @classmethod + def master(cls): + return "local[4]" + @classmethod def setUpClass(cls): cls.spark = ( PySparkSession.builder.config(conf=cls.conf()) .appName(cls.__name__) - .remote("local[4]") + .remote(cls.master()) .getOrCreate() ) cls.tempdir = tempfile.NamedTemporaryFile(delete=False) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index c51a3a5cce36..50452a5d9993 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -21,7 +21,7 @@ import java.io.File import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -127,6 +127,9 @@ case class AggregateInPandasExec( StructField(s"_$i", dt) }.toArray) + + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { // If we have session window expression in aggregation, we wrap iterator with @@ -170,7 +173,8 @@ case class AggregateInPandasExec( sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics).compute(projectedRowIter, 
context.partitionId(), context) + pythonMetrics, + jobArtifactUUID).compute(projectedRowIter, context.partitionId(), context) val joinedAttributes = groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index 35676406f141..9fc6ae04e94c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -61,8 +61,9 @@ class ApplyInPandasWithStatePythonRunner( keySchema: StructType, outputSchema: StructType, stateValueSchema: StructType, - val pythonMetrics: Map[String, SQLMetric]) - extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, jobArtifactUUID) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 86a5d13aed05..2e25ee2ba74d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ -import org.apache.spark.TaskContext +import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -67,6 +67,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val largeVarTypes = conf.arrowUseLargeVarTypes private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) protected override def evaluate( funcs: Seq[ChainedPythonFunctions], @@ -88,7 +89,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics).compute(batchIter, context.partitionId(), context) + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) columnarBatchIter.flatMap { batch => val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 175d67e90436..ea861df3c1f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -35,8 +35,10 @@ class ArrowPythonRunner( protected override val timeZoneId: String, protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric]) - extends BasePythonRunner[Iterator[InternalRow], 
ColumnarBatch](funcs, evalType, argOffsets) + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, evalType, argOffsets, jobArtifactUUID) with BasicPythonArrowInput with BasicPythonArrowOutput { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index c8a798d5b70c..71f1610bcec9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} -import org.apache.spark.TaskContext +import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -34,6 +34,8 @@ import org.apache.spark.sql.types.{StructField, StructType} case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute], child: SparkPlan) extends EvalPythonExec with PythonSQLMetrics { + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + protected override def evaluate( funcs: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], @@ -47,7 +49,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute] // Output iterator for results from Python. val outputIterator = - new PythonUDFRunner(funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics) + new PythonUDFRunner( + funcs, PythonEvalType.SQL_BATCHED_UDF, argOffsets, pythonMetrics, jobArtifactUUID) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala index b233f3983a70..feb3dfd0ba37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.Unpickler -import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext} +import org.apache.spark.{ContextAwareIterator, JobArtifactSet, SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -52,6 +52,8 @@ case class BatchEvalPythonUDTFExec( child: SparkPlan) extends UnaryExecNode with PythonSQLMetrics { + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + override def output: Seq[Attribute] = requiredChildOutput ++ resultAttrs override def producedAttributes: AttributeSet = AttributeSet(resultAttrs) @@ -145,7 +147,7 @@ case class BatchEvalPythonUDTFExec( // Output iterator for results from Python. 
val outputIterator = - new PythonUDTFRunner(udtf, argOffsets, pythonMetrics) + new PythonUDTFRunner(udtf, argOffsets, pythonMetrics, jobArtifactUUID) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler @@ -173,10 +175,11 @@ case class BatchEvalPythonUDTFExec( class PythonUDTFRunner( udtf: PythonUDTF, argOffsets: Array[Int], - pythonMetrics: Map[String, SQLMetric]) + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) extends BasePythonUDFRunner( Seq(ChainedPythonFunctions(Seq(udtf.func))), - PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics) { + PythonEvalType.SQL_TABLE_UDF, Array(argOffsets), pythonMetrics, jobArtifactUUID) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index 5d91b1a11ce7..eef8be7c940b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -47,9 +47,11 @@ class CoGroupedArrowPythonRunner( rightSchema: StructType, timeZoneId: String, conf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric]) + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) extends BasePythonRunner[ - (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](funcs, evalType, argOffsets) + (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch]( + funcs, evalType, argOffsets, jobArtifactUUID) with BasicPythonArrowOutput { override val pythonExec: String = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index 629df51e18ae..4da9fb8a60bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.JobArtifactSet import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -79,6 +80,7 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => @@ -97,7 +99,8 @@ case class FlatMapCoGroupsInPandasExec( StructType.fromAttributes(rightDedup), sessionLocalTimeZone, pythonRunnerConf, - pythonMetrics) + pythonMetrics, + jobArtifactUUID) executePython(data, output, runner) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 8da53cc6c997..77385e9a7d31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.JobArtifactSet import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -57,6 +58,7 @@ case class FlatMapGroupsInPandasExec( private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) private val pandasFunction = func.asInstanceOf[PythonUDF].func private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override def producedAttributes: AttributeSet = AttributeSet(output) @@ -92,7 +94,8 @@ case class FlatMapGroupsInPandasExec( sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics) + pythonMetrics, + jobArtifactUUID) executePython(data, output, runner) }} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala index 1c7ccaf4fe2a..8366c0c25ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.execution.python -import org.apache.spark.TaskContext +import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -72,6 +72,7 @@ case class FlatMapGroupsInPandasWithStateExec( override protected val initialStateDataAttrs: Seq[Attribute] = null override protected val initialState: SparkPlan = null override protected val hasInitialState: Boolean = false + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) override protected val stateEncoder: ExpressionEncoder[Any] = RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] @@ -176,7 +177,8 @@ case class FlatMapGroupsInPandasWithStateExec( groupingAttributes.toStructType, outAttributes.toStructType, stateType, - pythonMetrics) + pythonMetrics, + jobArtifactUUID) val context = TaskContext.get() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index 8281435ca92d..b4af3db3c83a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ -import org.apache.spark.{ContextAwareIterator, TaskContext} +import org.apache.spark.{ContextAwareIterator, JobArtifactSet, TaskContext} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -51,6 +51,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { private val largeVarTypes = conf.arrowUseLargeVarTypes + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + override def outputPartitioning: Partitioning = 
child.outputPartitioning override protected def doExecute(): RDD[InternalRow] = { @@ -81,7 +83,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics).compute(batchIter, context.partitionId(), context) + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) val unsafeProj = UnsafeProjection.create(output, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala index 986943acede4..3857f084bcb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonForeachWriter.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.concurrent.TimeUnit import java.util.concurrent.locks.ReentrantLock -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} import org.apache.spark.api.python._ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager @@ -39,6 +39,8 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) context.taskMemoryManager, new File(Utils.getLocalDir(SparkEnv.get.conf)), schema.fields.length) private lazy val inputRowIterator = buffer.iterator + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + private lazy val inputByteIterator = { EvaluatePython.registerPicklers() val objIterator = inputRowIterator.map { row => EvaluatePython.toJava(row, schema) } @@ -46,7 +48,7 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType) } private lazy val pythonRunner = { - new PythonRunner(Seq(ChainedPythonFunctions(Seq(func)))) { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID) { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala index 6a952d9099e8..22083e0473b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -33,9 +33,10 @@ abstract class BasePythonUDFRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric]) + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, evalType, argOffsets) { + funcs, evalType, argOffsets, jobArtifactUUID) { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( @@ -105,8 +106,9 @@ class PythonUDFRunner( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric]) - extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics) { + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends BasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { protected override def newWriterThread( env: SparkEnv, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index e6a65dd61dce..3d43c417dcb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -22,7 +22,7 @@ import java.io.File import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.{JobArtifactSet, SparkEnv, TaskContext} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -255,6 +255,7 @@ case class WindowInPandasExec( val allInputs = windowBoundsInput ++ dataInputs val allInputTypes = allInputs.map(_.dataType) val spillSize = longMetric("spillSize") + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) // Start processing. child.execute().mapPartitions { iter => @@ -388,7 +389,8 @@ case class WindowInPandasExec( sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics).compute(pythonInput, context.partitionId(), context) + pythonMetrics, + jobArtifactUUID).compute(pythonInput, context.partitionId(), context) val joined = new JoinedRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 2ab9d3c525cb..177a25b45fc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -176,9 +176,7 @@ class SessionResourceLoader(session: SparkSession) extends FunctionResourceLoade * to add the jar to its hive client for the current session. Hence, it still needs to be in * [[SessionState]]. */ - def addJar(path: String): Unit = addJar(path: String, sessionId = "default") - - private[spark] def addJar(path: String, sessionId: String): Unit = { + def addJar(path: String): Unit = { val uri = Utils.resolveURI(path) resolveJars(uri).foreach { p => session.sparkContext.addJar(p)
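
The hunks above converge on one mechanism: instead of threading an explicit sessionUUID argument through SparkContext.addJar/addFile and the Python runners, the artifact owner is taken from a thread-local JobArtifactState installed via JobArtifactSet.withActiveJobArtifactState, and files added for that state are fetched into a per-UUID subdirectory of SparkFiles.getRootDirectory(). The sketch below is illustrative only and is not part of the patch: the object and method names are invented, and it sits in the org.apache.spark package because JobArtifactSet and JobArtifactState are private[spark]. It mirrors the calling pattern used by SparkConnectArtifactManager.addArtifact and ClassLoaderIsolationSuite above.

package org.apache.spark

import java.io.File

// Hypothetical helper used only to illustrate the new scoping; not part of Spark.
object ArtifactScopingSketch {

  // Register artifacts so they are tracked under the state's UUID instead of the
  // global "default" key: addJar/addFile read the UUID from the thread-local that
  // withActiveJobArtifactState installs, so no explicit sessionUUID argument is passed.
  def addSessionArtifacts(
      sc: SparkContext,
      state: JobArtifactState,
      jarPath: String,
      filePath: String): Unit = {
    JobArtifactSet.withActiveJobArtifactState(state) {
      sc.addJar(jarPath)
      sc.addFile(filePath)
    }
  }

  // Resolve where files for the active state are fetched on the driver: a per-UUID
  // subdirectory of SparkFiles.getRootDirectory(), or the root directory itself when
  // no state is active (the regular, non-Connect path).
  def currentSessionRoot(): File =
    JobArtifactSet.getCurrentJobArtifactState
      .map(s => new File(SparkFiles.getRootDirectory(), s.uuid))
      .getOrElse(new File(SparkFiles.getRootDirectory()))
}

The same UUID is what BasePythonRunner forwards to workers through the SPARK_JOB_ARTIFACT_UUID environment variable, so PythonWorkerFactory can set the worker's working directory to the matching per-UUID subdirectory.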