@@ -17,9 +17,11 @@

package org.apache.spark.sql.connect.artifact

import java.io.File
import java.net.{URL, URLClassLoader}
import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import javax.ws.rs.core.UriBuilder

import scala.collection.JavaConverters._
import scala.reflect.ClassTag
@@ -99,16 +101,17 @@ class SparkConnectArtifactManager private[connect] {
private[connect] def addArtifact(
sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
serverLocalStagingPath: Path,
fragment: Option[String]): Unit = {
require(!remoteRelativePath.isAbsolute)
if (remoteRelativePath.startsWith("cache/")) {
if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
val blockManager = sessionHolder.session.sparkContext.env.blockManager
val blockId = CacheId(
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
hash = remoteRelativePath.toString.stripPrefix("cache/"))
hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
val updater = blockManager.TempFileBasedBlockStoreUpdater(
blockId = blockId,
level = StorageLevel.MEMORY_AND_DISK_SER,
@@ -118,9 +121,10 @@ class SparkConnectArtifactManager private[connect] {
tellMaster = false)
updater.save()
}(catchBlock = { tmpFile.delete() })
} else if (remoteRelativePath.startsWith("classes/")) {
} else if (remoteRelativePath.startsWith(s"classes${File.separator}")) {
// Move class files to common location (shared among all users)
val target = classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
val target = classArtifactDir.resolve(
remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
Files.createDirectories(target.getParent)
// Allow overwriting class files to capture updates to classes.
Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
@@ -135,17 +139,21 @@ class SparkConnectArtifactManager private[connect] {
s"Jars cannot be overwritten.")
}
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith("jars/")) {
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
// Adding Jars to the underlying spark context (visible to all users)
sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
} else if (remoteRelativePath.startsWith("pyfiles/")) {
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
val stringRemotePath = remoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
pythonIncludeList.add(target.getFileName.toString)
}
} else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(UriBuilder.fromUri(target.toUri).fragment).getOrElse(target.toUri)
sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
}
}
}
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.service

import java.io.File
import java.nio.file.{Files, Path, Paths}
import java.util.zip.{CheckedOutputStream, CRC32}

@@ -85,7 +86,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
}

protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact): Unit = {
artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath)
artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath, artifact.fragment)
}

/**
@@ -148,7 +149,21 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
* Handles rebuilding an artifact from bytes sent over the wire.
*/
class StagedArtifact(val name: String) {
val path: Path = Paths.get(name)
// Workaround to keep the fragment.
val (canonicalFileName: String, fragment: Option[String]) =
if (name.startsWith(s"archives${File.separator}")) {
val splits = name.split("#")
assert(splits.length <= 2, "'#' in the path is not supported for adding an archive.")
if (splits.length == 2) {
(splits(0), Some(splits(1)))
} else {
(splits(0), None)
}
} else {
(name, None)
}

val path: Path = Paths.get(canonicalFileName)
val stagedPath: Path = stagingDir.resolve(path)

Files.createDirectories(stagedPath.getParent)
@@ -48,7 +48,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("smallJar.jar")
val remotePath = Paths.get("jars/smallJar.jar")
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

val jarList = spark.sparkContext.listJars()
assert(jarList.exists(_.contains(remotePath.toString)))
@@ -60,7 +60,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("smallClassFile.class")
val remotePath = Paths.get("classes/smallClassFile.class")
assert(stagingPath.toFile.exists())
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile = classFileDirectory.resolve("smallClassFile.class").toFile
@@ -73,7 +73,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -96,7 +96,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
artifactManager.addArtifact(sessionHolder, remotePath, stagingPath, None)

val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -123,7 +123,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
val blockManager = spark.sparkContext.env.blockManager
val blockId = CacheId(session.userId, session.sessionId, "abc")
try {
artifactManager.addArtifact(session, remotePath, stagingPath)
artifactManager.addArtifact(session, remotePath, stagingPath, None)
val bytes = blockManager.getLocalBytes(blockId)
assert(bytes.isDefined)
val readback = new String(bytes.get.toByteBuffer().array(), StandardCharsets.UTF_8)
@@ -141,7 +141,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
val session = sessionHolder()
val remotePath = Paths.get("pyfiles/abc.zip")
artifactManager.addArtifact(session, remotePath, stagingPath)
artifactManager.addArtifact(session, remotePath, stagingPath, None)
assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
}
}
52 changes: 42 additions & 10 deletions python/pyspark/sql/connect/client/artifact.py
@@ -38,6 +38,7 @@

JAR_PREFIX: str = "jars"
PYFILE_PREFIX: str = "pyfiles"
ARCHIVE_PREFIX: str = "archives"


class LocalData(metaclass=abc.ABCMeta):
@@ -102,6 +103,10 @@ def new_pyfile_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(PYFILE_PREFIX, "", file_name, storage)


def new_archive_artifact(file_name: str, storage: LocalData) -> Artifact:
return _new_artifact(ARCHIVE_PREFIX, "", file_name, storage)


def _new_artifact(
prefix: str, required_suffix: str, file_name: str, storage: LocalData
) -> Artifact:
@@ -136,12 +141,16 @@ def __init__(self, user_id: Optional[str], session_id: str, channel: grpc.Channe
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id

def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]:
def _parse_artifacts(self, path_or_uri: str, pyfile: bool, archive: bool) -> List[Artifact]:
# Currently, only local files are supported.
uri = path_or_uri
if urlparse(path_or_uri).scheme == "": # Is path?
uri = Path(path_or_uri).absolute().as_uri()
parsed = urlparse(uri)
parsed = urlparse(path_or_uri)
# Check whether this is a local file path, based on the scheme.
if parsed.scheme == "":
# Similar to Utils.resolveURI.
fragment = parsed.fragment
parsed = urlparse(Path(url2pathname(parsed.path)).absolute().as_uri())
parsed = parsed._replace(fragment=fragment)

if parsed.scheme == "file":
local_path = url2pathname(parsed.path)
name = Path(local_path).name
@@ -154,29 +163,52 @@ def _parse_artifacts(self, path_or_uri: str, pyfile: bool) -> List[Artifact]:
sys.path.insert(1, local_path)
artifact = new_pyfile_artifact(name, LocalFile(local_path))
importlib.invalidate_caches()
elif archive and (
name.endswith(".zip")
or name.endswith(".jar")
or name.endswith(".tar.gz")
or name.endswith(".tgz")
or name.endswith(".tar")
):
assert any(name.endswith(s) for s in (".zip", ".jar", ".tar.gz", ".tgz", ".tar"))

if parsed.fragment != "":
# Minimal workaround for fragment handling in the URI.
# Limitation: a hash (#) in the file name itself is not supported.
if "#" in local_path:
raise ValueError("'#' in the path is not supported for adding an archive.")
name = f"{name}#{parsed.fragment}"
Review comment (Member Author): This is actually a pretty ugly workaround to support a fragment in the URI (but I believe this is the minimized change). Maybe we should pass a URI instead of a file path in Artifacts(?), but I would like to avoid touching the whole implementation in this PR. cc @hvanhovell @vicennial
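For context, a minimal standalone sketch of the fragment round-trip this comment describes: the raw string is parsed before it is converted to a file URI (so the "#fragment" is not percent-encoded away by Path.as_uri()), and the fragment is then re-encoded into the artifact name sent to the server, where UriBuilder reattaches it as a URI fragment. The helper names below (split_archive, archive_artifact_name) are made up for illustration and are not part of this change.

from pathlib import Path
from typing import Tuple
from urllib.parse import urlparse
from urllib.request import url2pathname


def split_archive(path_or_uri: str) -> Tuple[str, str]:
    # Parse the raw string first so the '#fragment' survives; converting to a
    # file URI up front would percent-encode the '#' and lose the fragment.
    parsed = urlparse(path_or_uri)
    fragment = parsed.fragment
    local_path = url2pathname(parsed.path)
    return str(Path(local_path).absolute()), fragment


def archive_artifact_name(path_or_uri: str) -> str:
    # Re-encode the fragment into the name transferred to the server,
    # e.g. 'deps.zip#deps'; the server splits it back out again.
    local_path, fragment = split_archive(path_or_uri)
    if "#" in local_path:
        raise ValueError("'#' in the path is not supported for adding an archive.")
    name = Path(local_path).name
    return f"{name}#{fragment}" if fragment else name


print(archive_artifact_name("dist/deps.zip#deps"))  # prints: deps.zip#deps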


artifact = new_archive_artifact(name, LocalFile(local_path))
elif name.endswith(".jar"):
artifact = new_jar_artifact(name, LocalFile(local_path))
else:
raise RuntimeError(f"Unsupported file format: {local_path}")
return [artifact]
raise RuntimeError(f"Unsupported scheme: {uri}")
raise RuntimeError(f"Unsupported scheme: {parsed.scheme}")

def _create_requests(self, *path: str, pyfile: bool) -> Iterator[proto.AddArtifactsRequest]:
def _create_requests(
self, *path: str, pyfile: bool, archive: bool
) -> Iterator[proto.AddArtifactsRequest]:
"""Separated for the testing purpose."""
return self._add_artifacts(chain(*(self._parse_artifacts(p, pyfile=pyfile) for p in path)))
return self._add_artifacts(
chain(*(self._parse_artifacts(p, pyfile=pyfile, archive=archive) for p in path))
)

def _retrieve_responses(
self, requests: Iterator[proto.AddArtifactsRequest]
) -> proto.AddArtifactsResponse:
"""Separated for the testing purpose."""
return self._stub.AddArtifacts(requests)

def add_artifacts(self, *path: str, pyfile: bool) -> None:
def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
"""
Add artifact(s) to the session.
Currently, only local files are supported.
"""
requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(*path, pyfile=pyfile)
requests: Iterator[proto.AddArtifactsRequest] = self._create_requests(
*path, pyfile=pyfile, archive=archive
)
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []

4 changes: 2 additions & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -1237,8 +1237,8 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
else:
raise SparkConnectGrpcException(str(rpc_error)) from None

def add_artifacts(self, *path: str, pyfile: bool) -> None:
self._artifact_manager.add_artifacts(*path, pyfile=pyfile)
def add_artifacts(self, *path: str, pyfile: bool, archive: bool) -> None:
self._artifact_manager.add_artifacts(*path, pyfile=pyfile, archive=archive)


class RetryState:
11 changes: 9 additions & 2 deletions python/pyspark/sql/connect/session.py
@@ -601,7 +601,7 @@ def client(self) -> "SparkConnectClient":
"""
return self._client

def addArtifacts(self, *path: str, pyfile: bool = False) -> None:
def addArtifacts(self, *path: str, pyfile: bool = False, archive: bool = False) -> None:
"""
Add artifact(s) to the client session. Currently only local files are supported.

Expand All @@ -613,8 +613,15 @@ def addArtifacts(self, *path: str, pyfile: bool = False) -> None:
Artifact's URIs to add.
pyfile : bool
Whether to add them as Python dependencies such as .py, .egg, .zip or .jar files.
The pyfiles are directly inserted into the Python search path when executing
Python functions on executors.
archive : bool
Whether to add them as archives such as .zip, .jar, .tar.gz, .tgz, or .tar files.
The archives are unpacked on the executor side automatically.
"""
self._client.add_artifacts(*path, pyfile=pyfile)
if pyfile and archive:
raise ValueError("'pyfile' and 'archive' cannot be True together.")
self._client.add_artifacts(*path, pyfile=pyfile, archive=archive)

addArtifact = addArtifacts
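
A hedged usage sketch of the new flag. The connect URL "sc://localhost", the archive "dist/deps.zip", the fragment "deps", and "config.txt" are made-up examples; the unpacked layout mirrors the test_add_archive test added at the bottom of this PR.

import os
from pyspark import SparkFiles
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# Ship deps.zip and unpack it on each executor under the directory name "deps".
spark.addArtifacts("dist/deps.zip#deps", archive=True)

@udf("string")
def read_config(_):
    # The unpacked archive is reachable from the executor's root directory.
    path = os.path.join(SparkFiles.getRootDirectory(), "deps", "config.txt")
    with open(path) as f:
        return f.read().strip()

spark.range(1).select(read_config("id")).show()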

44 changes: 39 additions & 5 deletions python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -21,6 +21,7 @@

from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect
from pyspark.testing.utils import SPARK_HOME
from pyspark import SparkFiles
from pyspark.sql.functions import udf

if should_test_connect:
@@ -48,7 +49,7 @@ def test_basic_requests(self):
file_name = "smallJar"
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
response = self.artifact_manager._retrieve_responses(
self.artifact_manager._create_requests(small_jar_path, pyfile=False)
self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
)
self.assertTrue(response.artifacts[0].name.endswith(f"{file_name}.jar"))

@@ -57,7 +58,9 @@ def test_single_chunk_artifact(self):
small_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")

requests = list(self.artifact_manager._create_requests(small_jar_path, pyfile=False))
requests = list(
self.artifact_manager._create_requests(small_jar_path, pyfile=False, archive=False)
)
self.assertEqual(len(requests), 1)

request = requests[0]
@@ -79,7 +82,9 @@ def test_chunked_artifacts(self):
large_jar_path = os.path.join(self.artifact_file_path, f"{file_name}.jar")
large_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")

requests = list(self.artifact_manager._create_requests(large_jar_path, pyfile=False))
requests = list(
self.artifact_manager._create_requests(large_jar_path, pyfile=False, archive=False)
)
# Expected chunks = roundUp( file_size / chunk_size) = 12
# File size of `junitLargeJar.jar` is 384581 bytes.
large_jar_size = os.path.getsize(large_jar_path)
@@ -111,7 +116,9 @@ def test_batched_artifacts(self):
small_jar_crc_path = os.path.join(self.artifact_crc_path, f"{file_name}.txt")

requests = list(
self.artifact_manager._create_requests(small_jar_path, small_jar_path, pyfile=False)
self.artifact_manager._create_requests(
small_jar_path, small_jar_path, pyfile=False, archive=False
)
)
# Single request containing 2 artifacts.
self.assertEqual(len(requests), 1)
@@ -147,7 +154,12 @@ def test_single_chunked_and_chunked_artifact(self):

requests = list(
self.artifact_manager._create_requests(
small_jar_path, large_jar_path, small_jar_path, small_jar_path, pyfile=False
small_jar_path,
large_jar_path,
small_jar_path,
small_jar_path,
pyfile=False,
archive=False,
)
)
# There are a total of 14 requests.
@@ -237,6 +249,28 @@ def func(x):
self.spark.addArtifacts(f"{package_path}.zip", pyfile=True)
self.assertEqual(self.spark.range(1).select(func("id")).first()[0], 5)

def test_add_archive(self):
with tempfile.TemporaryDirectory() as d:
archive_path = os.path.join(d, "my_archive")
os.mkdir(archive_path)
pyfile_path = os.path.join(archive_path, "my_file.txt")
with open(pyfile_path, "w") as f:
_ = f.write("hello world!")
shutil.make_archive(archive_path, "zip", d, "my_archive")

@udf("string")
def func(x):
with open(
os.path.join(
SparkFiles.getRootDirectory(), "my_files", "my_archive", "my_file.txt"
),
"r",
) as my_file:
return my_file.read().strip()

self.spark.addArtifacts(f"{archive_path}.zip#my_files", archive=True)
self.assertEqual(self.spark.range(1).select(func("id")).first()[0], "hello world!")


if __name__ == "__main__":
from pyspark.sql.tests.connect.client.test_artifact import * # noqa: F401