From 4d771a2d9d317b4c7765a4b9ff3a52af5360a0c7 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Wed, 24 May 2023 11:23:16 +0900
Subject: [PATCH] Implement the archive support in SparkSession.addArtifacts

---
 .../SparkConnectArtifactManager.scala         | 22 +++++---
 .../SparkConnectAddArtifactsHandler.scala     | 19 ++++++-
 .../artifact/ArtifactManagerSuite.scala       | 12 ++---
 python/pyspark/sql/connect/client/artifact.py | 52 +++++++++++++++----
 python/pyspark/sql/connect/client/core.py     |  4 +-
 python/pyspark/sql/connect/session.py         | 11 +++-
 .../sql/tests/connect/client/test_artifact.py | 44 ++++++++++++++--
 7 files changed, 130 insertions(+), 34 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 7a36c46c6722..604108f68d22 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -17,9 +17,11 @@
 package org.apache.spark.sql.connect.artifact

+import java.io.File
 import java.net.{URL, URLClassLoader}
 import java.nio.file.{Files, Path, Paths, StandardCopyOption}
 import java.util.concurrent.CopyOnWriteArrayList
+import javax.ws.rs.core.UriBuilder

 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag

@@ -99,16 +101,17 @@ class SparkConnectArtifactManager private[connect] {
   private[connect] def addArtifact(
       sessionHolder: SessionHolder,
       remoteRelativePath: Path,
-      serverLocalStagingPath: Path): Unit = {
+      serverLocalStagingPath: Path,
+      fragment: Option[String]): Unit = {
     require(!remoteRelativePath.isAbsolute)
-    if (remoteRelativePath.startsWith("cache/")) {
+    if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
       val tmpFile = serverLocalStagingPath.toFile
       Utils.tryWithSafeFinallyAndFailureCallbacks {
         val blockManager = sessionHolder.session.sparkContext.env.blockManager
         val blockId = CacheId(
           userId = sessionHolder.userId,
           sessionId = sessionHolder.sessionId,
-          hash = remoteRelativePath.toString.stripPrefix("cache/"))
+          hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
         val updater = blockManager.TempFileBasedBlockStoreUpdater(
           blockId = blockId,
           level = StorageLevel.MEMORY_AND_DISK_SER,
@@ -118,9 +121,10 @@ class SparkConnectArtifactManager private[connect] {
           tellMaster = false)
         updater.save()
       }(catchBlock = { tmpFile.delete() })
-    } else if (remoteRelativePath.startsWith("classes/")) {
+    } else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
       // Move class files to common location (shared among all users)
-      val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
+      val target = classArtifactDir.resolve(
+        remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
       Files.createDirectories(target.getParent)
       // Allow overwriting class files to capture updates to classes.
       Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
@@ -135,17 +139,21 @@ class SparkConnectArtifactManager private[connect] {
           s"Jars cannot be overwritten.")
       }
       Files.move(serverLocalStagingPath, target)
-      if (remoteRelativePath.startsWith("jars/")) {
+      if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
         // Adding Jars to the underlying spark context (visible to all users)
         sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
         jarsList.add(target)
-      } else if (remoteRelativePath.startsWith("pyfiles/")) {
+      } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
         sessionHolder.session.sparkContext.addFile(target.toString)
         val stringRemotePath = remoteRelativePath.toString
         if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
             ".egg") || stringRemotePath.endsWith(".jar")) {
           pythonIncludeList.add(target.getFileName.toString)
         }
+      } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
+        val canonicalUri =
+          fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
+        sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
       }
     }
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 99e92e42fffc..f8bdb58ed852 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.connect.service

+import java.io.File
 import java.nio.file.{Files, Path, Paths}
 import java.util.zip.{CheckedOutputStream, CRC32}

@@ -85,7 +86,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
   }

   protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
-    artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath)
+    artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment)
   }

   /**
@@ -148,7 +149,21 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
    * Handles rebuilding an artifact from bytes sent over the wire.
    */
   class StagedArtifact(val name: String) {
-    val path: Path = Paths.get(name)
+    // Workaround to keep the fragment.
+    val (canonicalFileName: String, fragment: Option[String]) =
+      if (name.startsWith(s"archives${File.separator}")) {
+        val splits = name.split("#")
+        assert(splits.length <= 2, "'#' in the path is not supported for adding an archive.")
+        if (splits.length == 2) {
+          (splits(0), Some(splits(1)))
+        } else {
+          (splits(0), None)
+        }
+      } else {
+        (name, None)
+      }
+
+    val path: Path = Paths.get(canonicalFileName)
     val stagedPath: Path = stagingDir.resolve(path)

     Files.createDirectories(stagedPath.getParent)
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 291eadb07c45..b87c6742bdcd 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -48,7 +48,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
     val stagingPath = copyDir.resolve("smallJar.jar")
     val remotePath = Paths.get("jars/smallJar.jar")
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

     val jarList = spark.sparkContext.listJars()
     assert(jarList.exists(_.contains(remotePath.toString)))
@@ -60,7 +60,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("smallClassFile.class")
     val remotePath = Paths.get("classes/smallClassFile.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile
@@ -73,7 +73,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -96,7 +96,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")
     assert(stagingPath.toFile.exists())
-    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
+    artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

     val classFileDirectory = artifactManager.classArtifactDir
     val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -123,7 +123,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     val blockManager = spark.sparkContext.env.blockManager
     val blockId = CacheId(session.userId, session.sessionId, "abc")
     try {
-      artifactManager.addArtifact(session, remotePath, stagingPath)
+      artifactManager.addArtifact(session, remotePath, stagingPath, None)
       val bytes = blockManager.getLocalBytes(blockId)
       assert(bytes.isDefined)
       val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8)
@@ -141,7 +141,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
     Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
     val session = sessionHolder()
     val remotePath = Paths.get("pyfiles/abc.zip")
-    artifactManager.addArtifact(session, remotePath, stagingPath)
+    artifactManager.addArtifact(session, remotePath, stagingPath, None)
     assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
   }
 }
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index f06277e50684..64f89119e4fb 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -38,6 +38,7 @@

 JAR_PREFIX: str = "jars"
 PYFILE_PREFIX: str = "pyfiles"
+ARCHIVE_PREFIX: str = "archives"


 class LocalData(metaclass=abc.ABCMeta):
@@ -102,6 +103,10 @@ def new_pyfile_artifact(file_name: str, storage: LocalData) -> Artifact:
     return _new_artifact(PYFILE_PREFIX, "", file_name, storage)


+def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
+    return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)
+
+
 def _new_artifact(
     prefix: str, required_suffix: str, file_name: str, storage: LocalData
 ) -> Artifact:
@@ -136,12 +141,16 @@ def __init__(self, user_id: Optional[str], session_id: str, channel: grpc.Channe
         self._stub = grpc_lib.SparkConnectServiceStub(channel)
         self._session_id = session_id

-    def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]:
+    def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
         # Currently only local files with .jar extension is supported.
-        uri = path_or_uri
-        if urlparse(path_or_uri).scheme == "":  # Is path?
-            uri = Path(path_or_uri).absolute().as_uri()
-        parsed = urlparse(uri)
+        parsed = urlparse(path_or_uri)
+        # Check if it is a file from the scheme
+        if parsed.scheme == "":
+            # Similar with Utils.resolveURI.
+            fragment = parsed.fragment
+            parsed = urlparse(Path(url2pathname(parsed.path)).absolute().as_uri())
+            parsed = parsed._replace(fragment=fragment)
+
         if parsed.scheme == "file":
             local_path = url2pathname(parsed.path)
             name = Path(local_path).name
@@ -154,16 +163,37 @@ def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]:
                 sys.path.insert(1, local_path)
                 artifact = new_pyfile_artifact(name, LocalFile(local_path))
                 importlib.invalidate_caches()
+            elif archive and (
+                name.endswith(".zip")
+                or name.endswith(".jar")
+                or name.endswith(".tar.gz")
+                or name.endswith(".tgz")
+                or name.endswith(".tar")
+            ):
+                assert any(name.endswith(s) for s in (".zip", ".jar", ".tar.gz", ".tgz", ".tar"))
+
+                if parsed.fragment != "":
+                    # Minimal fix for the workaround of fragment handling in URI.
+                    # This has a limitation - hash(#) in the file name would not work.
+ if "#" in local_path: + raise ValueError("'#' in the path is not supported for adding an archive.") + name = f"{name}#{parsed.fragment}" + + artifact = new_archive_artifact(name, LocalFile(local_path)) elif name.endswith(".jar"): artifact = new_jar_artifact(name, LocalFile(local_path)) else: raise RuntimeError(f"Unsupported file format: {local_path}") return [artifact] - raise RuntimeError(f"Unsupported scheme: {uri}") + raise RuntimeError(f"Unsupported scheme: {parsed.scheme}") - def _create_requests(self, *path: str, pyfile: bool) -> Iterator[proto.AddArtifactsRequest]: + def _create_requests( + self, *path: str, pyfile: bool, archive: bool + ) -> Iterator[proto.AddArtifactsRequest]: """Separated for the testing purpose.""" - return self._add_artifacts(chain(*(self._parse_artifacts(p, pyfile=pyfile) for p in path))) + return self._add_artifacts( + chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path)) + ) def _retrieve_responses( self, requests: Iterator[proto.AddArtifactsRequest] @@ -171,12 +201,14 @@ def _retrieve_responses( """Separated for the testing purpose.""" return self._stub.AddArtifacts(requests) - def add_artifacts(self, *path: str, pyfile: bool) -> None: + def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None: """ Add a single artifact to the session. Currently only local files with .jar extension is supported. """ - requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(*path, pyfile=pyfile) + requests: Iterator[proto.AddArtifactsRequest] = self._create_requests( + *path, pyfile=pyfile, archive=archive + ) response: proto.AddArtifactsResponse = self._retrieve_responses(requests) summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = [] diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 6aa25d973689..cd06419d969b 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1237,8 +1237,8 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: else: raise SparkConnectGrpcException(str(rpc_error)) from None - def add_artifacts(self, *path: str, pyfile: bool) -> None: - self._artifact_manager.add_artifacts(*path, pyfile=pyfile) + def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None: + self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive) class RetryState: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 3341b88eded9..7932ab540814 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -601,7 +601,7 @@ def client(self) -> "SparkConnectClient": """ return self._client - def addArtifacts(self, *path: str, pyfile: bool = False) -> None: + def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None: """ Add artifact(s) to the client session. Currently only local files are supported. @@ -613,8 +613,15 @@ def addArtifacts(self, *path: str, pyfile: bool = False) -> None: Artifact's URIs to add. pyfile : bool Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files. + The pyfiles are directly inserted into the path when executing Python functions + in executors. + archive : bool + Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files. + The archives are unpacked on the executor side automatically. 
""" - self._client.add_artifacts(*path, pyfile=pyfile) + if pyfile and archive: + raise ValueError("'pyfile' and 'archive' cannot be True together.") + self._client.add_artifacts(*path, pyfile=pyfile, archive=archive) addArtifact = addArtifacts diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index 73f47486bab4..2bff3fd5bc46 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -21,6 +21,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.utils import SPARK_HOME +from pyspark import SparkFiles from pyspark.sql.functions import udf if should_test_connect: @@ -48,7 +49,7 @@ def test_basic_requests(self): file_name = "smallJar" small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") response = self.artifact_manager._retrieve_responses( - self.artifact_manager._create_requests(small_jar_path, pyfile=False) + self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) ) self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar")) @@ -57,7 +58,9 @@ def test_single_chunk_artifact(self): small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") - requests = list(self.artifact_manager._create_requests(small_jar_path, pyfile=False)) + requests = list( + self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False) + ) self.assertEqual(len(requests), 1) request = requests[0] @@ -79,7 +82,9 @@ def test_chunked_artifacts(self): large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar") large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") - requests = list(self.artifact_manager._create_requests(large_jar_path, pyfile=False)) + requests = list( + self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False) + ) # Expected chunks = roundUp( file_size / chunk_size) = 12 # File size of `junitLargeJar.jar` is 384581 bytes. large_jar_size = os.path.getsize(large_jar_path) @@ -111,7 +116,9 @@ def test_batched_artifacts(self): small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt") requests = list( - self.artifact_manager._create_requests(small_jar_path, small_jar_path, pyfile=False) + self.artifact_manager._create_requests( + small_jar_path, small_jar_path, pyfile=False, archive=False + ) ) # Single request containing 2 artifacts. self.assertEqual(len(requests), 1) @@ -147,7 +154,12 @@ def test_single_chunked_and_chunked_artifact(self): requests = list( self.artifact_manager._create_requests( - small_jar_path, large_jar_path, small_jar_path, small_jar_path, pyfile=False + small_jar_path, + large_jar_path, + small_jar_path, + small_jar_path, + pyfile=False, + archive=False, ) ) # There are a total of 14 requests. 
@@ -237,6 +249,28 @@ def func(x):
         self.spark.addArtifacts(f"{package_path}.zip", pyfile=True)
         self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 5)

+    def test_add_archive(self):
+        with tempfile.TemporaryDirectory() as d:
+            archive_path = os.path.join(d, "my_archive")
+            os.mkdir(archive_path)
+            pyfile_path = os.path.join(archive_path, "my_file.txt")
+            with open(pyfile_path, "w") as f:
+                _ = f.write("hello world!")
+            shutil.make_archive(archive_path, "zip", d, "my_archive")
+
+            @udf("string")
+            def func(x):
+                with open(
+                    os.path.join(
+                        SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
+                    ),
+                    "r",
+                ) as my_file:
+                    return my_file.read().strip()
+
+            self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
+            self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!")
+

 if __name__ == "__main__":
     from pyspark.sql.tests.connect.client.test_artifact import *  # noqa: F401
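Usage note (illustrative, not part of the diff above): the sketch below shows how a client would exercise the new archive support end to end. It assumes a reachable Spark Connect endpoint at sc://localhost and a local archive /tmp/deps.zip containing my_file.txt; both names are hypothetical placeholders. The optional "#deps" fragment renames the directory the archive is unpacked into on executors, mirroring SparkContext.addArchive semantics.

# Illustrative sketch only. Assumes a running Spark Connect server at
# sc://localhost and a local archive /tmp/deps.zip that contains my_file.txt.
import os

from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Ship the archive; executors unpack it under a directory named "deps"
# because of the "#deps" fragment (same behavior as SparkContext.addArchive).
spark.addArtifacts("/tmp/deps.zip#deps", archive=True)


@udf("string")
def read_from_archive(_):
    # Unpacked archives are rooted at SparkFiles.getRootDirectory() on executors.
    path = os.path.join(SparkFiles.getRootDirectory(), "deps", "my_file.txt")
    with open(path, "r") as f:
        return f.read().strip()


spark.range(1).select(read_from_archive("id")).show()

As implemented in this patch, pyfile=True and archive=True together raise a ValueError on the client, and "#" is honored only as the fragment separator: a literal "#" in the archive's file name is rejected when a fragment is present.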