diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 5178013e455b..a9b6f102a512 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -288,7 +288,17 @@ object CheckConnectJvmClientCompatibility {

       // SQLImplicits
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
-      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"))
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"),
+
+      // Artifact Manager
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.artifact.ArtifactManager"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.artifact.ArtifactManager$"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.artifact.util.ArtifactUtils"),
+      ProblemFilters.exclude[MissingClassProblem](
+        "org.apache.spark.sql.artifact.util.ArtifactUtils$"))

     checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
   }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 1a5944676f5f..f7aa98af2fa3 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit

 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.buildConf

 object Connect {
@@ -206,20 +207,6 @@ object Connect {
       .intConf
       .createWithDefault(1024)

-  val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
-    buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
-      .internal()
-      .doc("""
-        |Allow `spark.copyFromLocalToFs` destination to be local file system
-        | path on spark driver node when
-        |`spark.connect.copyFromLocalToFs.allowDestLocal` is true.
-        |This will allow user to overwrite arbitrary file on spark
-        |driver node we should only enable it for testing purpose.
-        |""".stripMargin)
-      .version("3.5.0")
-      .booleanConf
-      .createWithDefault(false)
-
   val CONNECT_UI_STATEMENT_LIMIT =
     buildStaticConf("spark.sql.connect.ui.retainedStatements")
       .doc("The number of statements kept in the Spark Connect UI history.")
@@ -227,6 +214,17 @@ object Connect {
       .intConf
       .createWithDefault(200)

+  val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
+    buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
+      .internal()
+      .doc(s"""
+        |(Deprecated since Spark 4.0, please set
+        |'${SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.)
+ |""".stripMargin) + .version("3.5.0") + .booleanConf + .createWithDefault(false) + val CONNECT_UI_SESSION_LIMIT = buildStaticConf("spark.sql.connect.ui.retainedSessions") .doc("The number of client sessions kept in the Spark Connect UI history.") .version("3.5.0") diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 373ae0f90c6d..4a0aa7e55898 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -942,7 +942,7 @@ class SparkConnectPlanner( command = fun.getCommand.toByteArray.toImmutableArraySeq, // Empty environment variables envVars = Maps.newHashMap(), - pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, + pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava, pythonExec = pythonExec, pythonVer = fun.getPythonVer, // Empty broadcast variables @@ -996,7 +996,7 @@ class SparkConnectPlanner( private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = { val blockManager = session.sparkContext.env.blockManager - val blockId = CacheId(sessionHolder.userId, sessionHolder.sessionId, rel.getHash) + val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash) val bytes = blockManager.getLocalBytes(blockId) bytes .map { blockData => @@ -1014,7 +1014,7 @@ class SparkConnectPlanner( .getOrElse { throw InvalidPlanInput( s"Not found any cached local relation with the hash: ${blockId.hash} in " + - s"the session ${blockId.sessionId} for the user id ${blockId.userId}.") + s"the session with sessionUUID ${blockId.sessionUUID}.") } } @@ -1633,7 +1633,7 @@ class SparkConnectPlanner( command = fun.getCommand.toByteArray.toImmutableArraySeq, // Empty environment variables envVars = Maps.newHashMap(), - pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, + pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava, pythonExec = pythonExec, pythonVer = fun.getPythonVer, // Empty broadcast variables diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 0c55e30ba501..fd7c10d5c400 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -27,18 +27,16 @@ import scala.jdk.CollectionConverters._ import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder -import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException} +import org.apache.spark.{SparkException, SparkSQLException} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} import 
 import org.apache.spark.util.SystemClock
-import org.apache.spark.util.Utils

 // Unique key identifying session by combination of user, and session id
 case class SessionKey(userId: String, sessionId: String)
@@ -166,7 +164,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     interruptedIds.toSeq
   }

-  private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this)
+  private[connect] def artifactManager = session.artifactManager

   /**
    * Add an artifact to this SparkConnect session.
@@ -238,27 +236,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
     eventManager.postClosed()
   }

-  /**
-   * Execute a block of code using this session's classloader.
-   * @param f
-   * @tparam T
-   */
-  def withContextClassLoader[T](f: => T): T = {
-    // Needed for deserializing and evaluating the UDF on the driver
-    Utils.withContextClassLoader(classloader) {
-      JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
-        f
-      }
-    }
-  }
-
   /**
    * Execute a block of code with this session as the active SparkConnect session.
    * @param f
    * @tparam T
   */
   def withSession[T](f: SparkSession => T): T = {
-    withContextClassLoader {
+    artifactManager.withResources {
       session.withActive {
         f(session)
       }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 636054198fbf..e664e07dce11 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -30,8 +30,8 @@ import io.grpc.stub.StreamObserver
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
-import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
+import org.apache.spark.sql.artifact.ArtifactManager
+import org.apache.spark.sql.artifact.util.ArtifactUtils
 import org.apache.spark.util.Utils

 /**
@@ -101,8 +101,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
     // We do not store artifacts that fail the CRC. The failure is reported in the artifact
     // summary and it is up to the client to decide whether to retry sending the artifact.
     if (artifact.getCrcStatus.contains(true)) {
-      if (artifact.path.startsWith(
-          SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) {
+      if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
         holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
       } else {
         addStagedArtifactToArtifactManager(artifact)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
index 325832ac07e6..78def077f2dd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
@@ -33,7 +33,7 @@ class SparkConnectArtifactStatusesHandler(
       .getOrCreateIsolatedSession(userId, sessionId)
       .session
     val blockManager = session.sparkContext.env.blockManager
-    blockManager.getStatus(CacheId(userId, sessionId, hash)).isDefined
+    blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
   }

   def handle(request: proto.ArtifactStatusesRequest): Unit = {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 2eaa8c8383e3..bb51b0a79820 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -163,7 +163,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       SimplePythonFunction(
         command = fcn(sparkPythonPath).toImmutableArraySeq,
         envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
-        pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+        pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
         pythonExec = IntegratedUDFTestUtils.pythonExec,
         pythonVer = IntegratedUDFTestUtils.pythonVer,
         broadcastVars = Lists.newArrayList(),
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 456b4edf9383..585d9a886b47 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -190,8 +190,8 @@ class UnrecognizedBlockId(name: String)
     extends SparkException(s"Failed to parse $name into a block ID")

 @DeveloperApi
-case class CacheId(userId: String, sessionId: String, hash: String) extends BlockId {
-  override def name: String = s"cache_${userId}_${sessionId}_$hash"
+case class CacheId(sessionUUID: String, hash: String) extends BlockId {
+  override def name: String = s"cache_${sessionUUID}_$hash"
 }

 @DeveloperApi
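The `CacheId` key for cached local relations is now derived from the session UUID alone rather than from the `(userId, sessionId)` pair. For illustration only, a rough sketch of how Spark-internal code could build and probe such a block after this change (the hash value is a placeholder; `spark` is assumed to be an active SparkSession whose `sessionUUID` is visible to the calling code):

    import org.apache.spark.SparkEnv
    import org.apache.spark.storage.CacheId

    // Blocks are now keyed by the owning SparkSession's UUID plus the artifact hash.
    val blockId = CacheId(sessionUUID = spark.sessionUUID, hash = "deadbeef")
    val isCached = SparkEnv.get.blockManager.getStatus(blockId).isDefined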
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e64c33382dce..8c22f8473e69 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -2057,10 +2057,10 @@ private[spark] class BlockManager(
    *
    * @return The number of blocks removed.
    */
-  def removeCache(userId: String, sessionId: String): Int = {
-    logDebug(s"Removing cache of user id = $userId in the session $sessionId")
+  def removeCache(sessionUUID: String): Int = {
+    logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
     val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
-      case cid: CacheId if cid.userId == userId && cid.sessionId == sessionId => cid
+      case cid: CacheId if cid.sessionUUID == sessionUUID => cid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId) }
     blocksToRemove.size
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d080b16fdc5c..463212290877 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -90,6 +90,16 @@ object MimaExcludes {
     // SPARK-43299: Convert StreamingQueryException in Scala Client
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException"),

+    // SPARK-45856: Move ArtifactManager from Spark Connect into SparkSession (sql/core)
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.userId"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.sessionId"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy$default$3"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.this"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.CacheId$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),
+
     (problem: Problem) => problem match {
       case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") &&
           !cls.fullName.startsWith("org.sparkproject.dmg.pmml")
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py b/python/pyspark/ml/tests/connect/test_connect_classification.py
index ccf7c346be72..1f811c774cbd 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -38,7 +38,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
    def setUp(self) -> None:
        self.spark = (
            SparkSession.builder.remote("local[2]")
-            .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+            .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
            .getOrCreate()
        )

diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 6925b2482f24..6a895e892397 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -30,7 +30,7 @@ class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
    def setUp(self) -> None:
        self.spark = (
            SparkSession.builder.remote("local[2]")
-            .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+            .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
            .getOrCreate()
        )

diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index d7dbb00b5e17..7b10d91da064 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -30,7 +30,7 @@ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
    def setUp(self) -> None:
        self.spark = (
            SparkSession.builder.remote("local[2]")
-            .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+            .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
            .getOrCreate()
        )

diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 7e9f9dbbf569..7fde0958e381 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -183,7 +183,7 @@ def setUpClass(cls):
    @classmethod
    def conf(cls):
        conf = super().conf()
-        conf.set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+        conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
        return conf

    def test_basic_requests(self):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2d67a8428d22..6a8e1f92fc51 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4550,6 +4550,23 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  // Deprecate "spark.connect.copyFromLocalToFs.allowDestLocal" in favor of this config. This is
+  // currently optional because we don't want to break existing users who are using the old config.
+  // If this config is set, then we override the deprecated config.
+  val ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
+    buildConf("spark.sql.artifact.copyFromLocalToFs.allowDestLocal")
+      .internal()
+      .doc("""
+        |Allow the `spark.copyFromLocalToFs` destination to be a local file system
+        |path on the Spark driver node when
+        |`spark.sql.artifact.copyFromLocalToFs.allowDestLocal` is true.
+        |This allows the user to overwrite arbitrary files on the Spark driver node,
+        |so it should only be enabled for testing purposes.
+        |""".stripMargin)
+      .version("4.0.0")
+      .booleanConf
+      .createOptional
+
   val LEGACY_RETAIN_FRACTION_DIGITS_FIRST =
     buildConf("spark.sql.legacy.decimal.retainFractionDigitsOnTruncate")
       .internal()
@@ -4617,7 +4634,9 @@ object SQLConf {
       DeprecatedConfig(COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "3.2",
         s"Use '${COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key}' instead."),
       DeprecatedConfig(ESCAPED_STRING_LITERALS.key, "4.0",
-        "Use raw string literals with the `r` prefix instead. ")
+        "Use raw string literals with the `r` prefix instead. "),
+      DeprecatedConfig("spark.connect.copyFromLocalToFs.allowDestLocal", "4.0",
+        s"Use '${ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.")
     )

     Map(configs.map { cfg => cfg.key -> cfg } : _*)
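Since the new SQL config is `createOptional`, precedence between it and the deprecated Connect flag is resolved at the point of use. A minimal sketch of that resolution, mirroring the fallback used later in `ArtifactManager.uploadArtifactToFs`; it assumes code running inside the `org.apache.spark.sql` package, where the `ConfigEntry`-based getter on `RuntimeConfig` is visible:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf

    def allowDestLocal(session: SparkSession): Boolean = {
      // Prefer the new flag; if it is unset, fall back to the deprecated Connect flag.
      session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
        .getOrElse(
          session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal", "false") == "true")
    }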
"), + DeprecatedConfig("spark.connect.copyFromLocalToFs.allowDestLocal", "4.0", + s"Use '${ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.") ) Map(configs.map { cfg => cfg.key -> cfg } : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 779015ee13eb..5eba9e59c17b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.artifact.ArtifactManager import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation} @@ -243,6 +244,16 @@ class SparkSession private( @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager + /** + * Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts + * (jars, classfiles, etc). + * + * @since 4.0.0 + */ + @Experimental + @Unstable + private[sql] def artifactManager: ArtifactManager = sessionState.artifactManager + /** * Start a new session with isolated SQL configurations, temporary tables, registered * functions are isolated, but sharing the underlying `SparkContext` and cached data. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala similarity index 58% rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala rename to sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala index ba36b708e83a..69a5fd860740 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala @@ -15,7 +15,7 @@ * limitations under the License. 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
similarity index 58%
rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index ba36b708e83a..69a5fd860740 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact

 import java.io.File
 import java.net.{URI, URL, URLClassLoader}
@@ -26,62 +26,85 @@ import javax.ws.rs.core.UriBuilder
 import scala.jdk.CollectionConverters._
 import scala.reflect.ClassTag

-import io.grpc.Status
 import org.apache.commons.io.{FilenameUtils, FileUtils}
 import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

-import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
+import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES, EXECUTOR_USER_CLASS_PATH_FIRST}
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
-import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
-import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.artifact.util.ArtifactUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.storage.{CacheId, StorageLevel}
 import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}

 /**
- * The Artifact Manager for the [[SparkConnectService]].
- *
  * This class handles the storage of artifacts as well as preparing the artifacts for use.
  *
- * Artifacts belonging to different [[SparkSession]]s are segregated and isolated from each other
- * with the help of the `sessionUUID`.
+ * Artifacts belonging to different SparkSessions are isolated from each other with the help of the
+ * `sessionUUID`.
  *
- * Jars and classfile artifacts are stored under "jars" and "classes" sub-directories respectively
- * while other types of artifacts are stored under the root directory for that particular
- * [[SparkSession]].
+ * Jar, classfile and py-file artifacts are stored under the "jars", "classes" and "pyfiles"
+ * sub-directories respectively, while other types of artifacts are stored under the root
+ * directory for that particular SparkSession.
  *
- * @param sessionHolder
- *   The object used to hold the Spark Connect session state.
+ * @param session The object used to hold the Spark session state.
  */
-class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
-  import SparkConnectArtifactManager._
+class ArtifactManager(session: SparkSession) extends Logging {
+  import ArtifactManager._
+
+  // The base directory where all artifacts are stored.
+  protected def artifactRootPath: Path = artifactRootDirectory
+
+  private[artifact] lazy val artifactRootURI: String = SparkEnv
+    .get
+    .rpcEnv
+    .fileServer
+    .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)

   // The base directory/URI where all artifacts are stored for this `sessionUUID`.
-  val (artifactPath, artifactURI): (Path, String) =
-    getArtifactDirectoryAndUriForSession(sessionHolder)
+  protected[artifact] val (artifactPath, artifactURI): (Path, String) =
+    (ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
+      s"$artifactRootURI/${session.sessionUUID}")
+
   // The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
-  val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)
-  val state: JobArtifactState =
-    JobArtifactState(sessionHolder.serverSessionId, Option(classURI))
+  protected[artifact] val (classDir, classURI): (Path, String) =
+    (ArtifactUtils.concatenatePaths(artifactPath, "classes"), s"$artifactURI/classes/")
+
+  protected[artifact] val state: JobArtifactState =
+    JobArtifactState(session.sessionUUID, Option(classURI))
+
+  def withResources[T](f: => T): T = {
+    Utils.withContextClassLoader(classloader) {
+      JobArtifactSet.withActiveJobArtifactState(state) {
+        f
+      }
+    }
+  }

-  private val jarsList = new CopyOnWriteArrayList[Path]
-  private val pythonIncludeList = new CopyOnWriteArrayList[String]
+  protected val jarsList = new CopyOnWriteArrayList[Path]
+  protected val pythonIncludeList = new CopyOnWriteArrayList[String]

   /**
-   * Get the URLs of all jar artifacts added through the [[SparkConnectService]].
-   *
-   * @return
+   * Get the URLs of all jar artifacts.
    */
-  def getSparkConnectAddedJars: Seq[URL] = jarsList.asScala.map(_.toUri.toURL).toSeq
+  def getAddedJars: Seq[URL] = jarsList.asScala.map(_.toUri.toURL).toSeq

   /**
-   * Get the py-file names added through the [[SparkConnectService]].
+   * Get the py-file names added to this SparkSession.
    *
    * @return
    */
-  def getSparkConnectPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
+  def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
+
+  protected def moveFile(source: Path, target: Path, allowOverwrite: Boolean = false): Unit = {
+    Files.createDirectories(target.getParent)
+    if (allowOverwrite) {
+      Files.move(source, target, StandardCopyOption.REPLACE_EXISTING)
+    } else {
+      Files.move(source, target)
+    }
+  }

   /**
    * Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes
@@ -91,7 +114,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
    * @param serverLocalStagingPath
    * @param fragment
    */
-  private[connect] def addArtifact(
+  def addArtifact(
       remoteRelativePath: Path,
       serverLocalStagingPath: Path,
       fragment: Option[String]): Unit = JobArtifactSet.withActiveJobArtifactState(state) {
@@ -99,10 +122,9 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
     if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
       val tmpFile = serverLocalStagingPath.toFile
       Utils.tryWithSafeFinallyAndFailureCallbacks {
-        val blockManager = sessionHolder.session.sparkContext.env.blockManager
+        val blockManager = session.sparkContext.env.blockManager
         val blockId = CacheId(
-          userId = sessionHolder.userId,
-          sessionId = sessionHolder.sessionId,
+          sessionUUID = session.sessionUUID,
           hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
         val updater = blockManager.TempFileBasedBlockStoreUpdater(
           blockId = blockId,
@@ -118,15 +140,12 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
       val target = ArtifactUtils.concatenatePaths(
         classDir,
         remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
-      Files.createDirectories(target.getParent)
       // Allow overwriting class files to capture updates to classes.
       // This is required because the client currently sends all the class files in each class file
       // transfer.
-      Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
+      moveFile(serverLocalStagingPath, target, allowOverwrite = true)
     } else {
       val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
-      Files.createDirectories(target.getParent)
-
       // Disallow overwriting with modified version
       if (Files.exists(target)) {
         // makes the query idempotent
@@ -134,22 +153,20 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
           return
         }

-        throw Status.ALREADY_EXISTS
-          .withDescription(s"Duplicate Artifact: $remoteRelativePath. " +
+        throw new RuntimeException(s"Duplicate Artifact: $remoteRelativePath. " +
           "Artifacts cannot be overwritten.")
-          .asRuntimeException()
       }
-      Files.move(serverLocalStagingPath, target)
+      moveFile(serverLocalStagingPath, target)

       // This URI is for Spark file server that starts with "spark://".
       val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath(
           FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}"

       if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
-        sessionHolder.session.sparkContext.addJar(uri)
+        session.sparkContext.addJar(uri)
         jarsList.add(target)
       } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
-        sessionHolder.session.sparkContext.addFile(uri)
+        session.sparkContext.addFile(uri)
         val stringRemotePath = remoteRelativePath.toString
         if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
             ".egg") || stringRemotePath.endsWith(".jar")) {
@@ -158,9 +175,9 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
       } else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
         val canonicalUri =
           fragment.map(UriBuilder.fromUri(new URI(uri)).fragment).getOrElse(new URI(uri))
-        sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+        session.sparkContext.addArchive(canonicalUri.toString)
       } else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
-        sessionHolder.session.sparkContext.addFile(uri)
+        session.sparkContext.addFile(uri)
       }
     }
   }
@@ -169,7 +186,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
    * Returns a [[ClassLoader]] for session-specific jar/class file resources.
    */
   def classloader: ClassLoader = {
-    val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
+    val urls = getAddedJars :+ classDir.toUri.toURL
     val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
     val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
     val loader = if (prefixes.nonEmpty) {
@@ -208,35 +225,34 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
   }

   /**
-   * Cleans up all resources specific to this `sessionHolder`.
+   * Cleans up all resources specific to this `session`.
    */
-  private[connect] def cleanUpResources(): Unit = {
+  private[sql] def cleanUpResources(): Unit = {
     logDebug(
-      s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
-        s"sessionId: ${sessionHolder.sessionId}")
+      s"Cleaning up resources for session with sessionUUID ${session.sessionUUID}")

     // Clean up added files
     val fileserver = SparkEnv.get.rpcEnv.fileServer
-    val sparkContext = sessionHolder.session.sparkContext
+    val sparkContext = session.sparkContext
     sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
     sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
     sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar))

     // Clean up cached relations
     val blockManager = sparkContext.env.blockManager
-    blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)
+    blockManager.removeCache(session.sessionUUID)

     // Clean up artifacts folder
     FileUtils.deleteDirectory(artifactPath.toFile)
   }

-  private[connect] def uploadArtifactToFs(
+  def uploadArtifactToFs(
       remoteRelativePath: Path,
       serverLocalStagingPath: Path): Unit = {
-    val hadoopConf = sessionHolder.session.sparkContext.hadoopConfiguration
+    val hadoopConf = session.sparkContext.hadoopConfiguration
     assert(
       remoteRelativePath.startsWith(
-        SparkConnectArtifactManager.forwardToFSPrefix + File.separator))
+        ArtifactManager.forwardToFSPrefix + File.separator))
     val destFSPath = new FSPath(
       Paths
         .get("/")
@@ -246,14 +262,17 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
     val fs = destFSPath.getFileSystem(hadoopConf)
     if (fs.isInstanceOf[LocalFileSystem]) {
       val allowDestLocalConf =
-        SparkEnv.get.conf.get(CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
+        session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
+          .getOrElse(
+            session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal").contains("true"))
+
       if (!allowDestLocalConf) {
         // To avoid security issue, by default,
         // we don't support uploading file to local file system
         // destination path, otherwise user is able to overwrite arbitrary file
         // on spark driver node.
         // We can temporarily allow the behavior by setting spark config
-        // `spark.connect.copyFromLocalToFs.allowDestLocal`
+        // `spark.sql.artifact.copyFromLocalToFs.allowDestLocal`
        // to `true` when starting spark driver, we should only enable it for testing
        // purpose.
         throw new UnsupportedOperationException(
@@ -264,80 +283,12 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
   }
 }

-object SparkConnectArtifactManager extends Logging {
+object ArtifactManager extends Logging {

   val forwardToFSPrefix = "forward_to_fs"

-  private var currentArtifactRootUri: String = _
-  private var lastKnownSparkContextInstance: SparkContext = _
-
-  private val ARTIFACT_DIRECTORY_PREFIX = "artifacts"
+  val ARTIFACT_DIRECTORY_PREFIX = "artifacts"

-  // The base directory where all artifacts are stored.
-  private[spark] lazy val artifactRootPath = {
+  private[artifact] lazy val artifactRootDirectory =
     Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath
-  }
-
-  private[spark] def getArtifactDirectoryAndUriForSession(session: SparkSession): (Path, String) =
-    (
-      ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
-      s"$artifactRootURI/${session.sessionUUID}")
-
-  private[spark] def getArtifactDirectoryAndUriForSession(
-      sessionHolder: SessionHolder): (Path, String) =
-    getArtifactDirectoryAndUriForSession(sessionHolder.session)
-
-  private[spark] def getClassfileDirectoryAndUriForSession(
-      session: SparkSession): (Path, String) = {
-    val (artDir, artUri) = getArtifactDirectoryAndUriForSession(session)
-    (ArtifactUtils.concatenatePaths(artDir, "classes"), s"$artUri/classes/")
-  }
-
-  private[spark] def getClassfileDirectoryAndUriForSession(
-      sessionHolder: SessionHolder): (Path, String) =
-    getClassfileDirectoryAndUriForSession(sessionHolder.session)
-
-  /**
-   * Updates the URI for the artifact directory.
-   *
-   * This is required if the SparkContext is restarted.
-   *
-   * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
-   * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
-   * not expected to be restarted at any point in time.
-   */
-  private def refreshArtifactUri(sc: SparkContext): Unit = synchronized {
-    // If a competing thread had updated the URI, we do not need to refresh the URI again.
-    if (sc eq lastKnownSparkContextInstance) {
-      return
-    }
-    val oldArtifactUri = currentArtifactRootUri
-    currentArtifactRootUri = SparkEnv.get.rpcEnv.fileServer
-      .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)
-    lastKnownSparkContextInstance = sc
-    logDebug(s"Artifact URI updated from $oldArtifactUri to $currentArtifactRootUri")
-  }
-
-  /**
-   * Checks if the URI for the artifact directory needs to be updated. This is required in cases
-   * where SparkContext is restarted as the old URI would no longer be valid.
-   *
-   * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
-   * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
-   * not expected to be restarted at any point in time.
-   */
-  private def updateUriIfRequired(): Unit = {
-    SparkContext.getActive.foreach { sc =>
-      if (lastKnownSparkContextInstance == null || (sc ne lastKnownSparkContextInstance)) {
-        logDebug("Refreshing artifact URI due to SparkContext (re)initialisation!")
-        refreshArtifactUri(sc)
-      }
-    }
-  }
-
-  private[connect] def artifactRootURI: String = {
-    updateUriIfRequired()
-    require(currentArtifactRootUri != null)
-    currentArtifactRootUri
-  }
 }
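To run user code against artifacts added to a particular session, callers wrap the work in `withResources`, which installs the session-specific classloader and `JobArtifactState`. A condensed sketch of the pattern the test suite below uses (it assumes a `Hello` class, taking a `String` constructor argument and behaving as `String => String`, has already been added via `addArtifact`):

    import org.apache.spark.sql.functions.{col, udf}

    val loader = spark.artifactManager.classloader
    val hello = loader
      .loadClass("Hello")
      .getDeclaredConstructor(classOf[String])
      .newInstance("Talon")
      .asInstanceOf[String => String]
    val helloUdf = udf(hello)

    spark.artifactManager.withResources {
      // UDF deserialization on the driver resolves Hello through the session classloader.
      spark.range(5).select(helloUdf(col("id").cast("string"))).collect()
    }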
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
similarity index 88%
rename from connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
index ab1c0f816594..f16d01501d7c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
@@ -15,13 +15,13 @@
  * limitations under the License.
  */

-package org.apache.spark.sql.connect.artifact.util
+package org.apache.spark.sql.artifact.util

 import java.nio.file.{Path, Paths}

 object ArtifactUtils {

-  private[connect] def concatenatePaths(basePath: Path, otherPath: Path): Path = {
+  private[sql] def concatenatePaths(basePath: Path, otherPath: Path): Path = {
     require(!otherPath.isAbsolute)
     // We avoid using the `.resolve()` method here to ensure that we're concatenating the two
     // paths.
@@ -37,7 +37,7 @@ object ArtifactUtils {
     normalizedPath
   }

-  private[connect] def concatenatePaths(basePath: Path, otherPath: String): Path = {
+  private[sql] def concatenatePaths(basePath: Path, otherPath: String): Path = {
     concatenatePaths(basePath, Paths.get(otherPath))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 1d496b027ef5..630e1202f6d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.internal

 import org.apache.spark.annotation.Unstable
 import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
+import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.Expression
@@ -349,6 +350,12 @@ abstract class BaseSessionStateBuilder(
       new ExecutionListenerManager(session, conf, loadExtensions = true))
   }

+  /**
+   * Resource manager that handles the storage of artifacts as well as preparing the artifacts for
+   * use.
+   */
+  protected def artifactManager: ArtifactManager = new ArtifactManager(session)
+
   /**
    * Function used to make clones of the session state.
    */
@@ -381,7 +388,8 @@ abstract class BaseSessionStateBuilder(
       createClone,
       columnarRules,
       adaptiveRulesHolder,
-      planNormalizationRules)
+      planNormalizationRules,
+      () => artifactManager)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 177a25b45fc3..adf3e0cb6cad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.annotation.Unstable
 import org.apache.spark.sql._
+import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -84,7 +85,8 @@ private[sql] class SessionState(
     createClone: (SparkSession, SessionState) => SessionState,
     val columnarRules: Seq[ColumnarRule],
     val adaptiveRulesHolder: AdaptiveRulesHolder,
-    val planNormalizationRules: Seq[Rule[LogicalPlan]]) {
+    val planNormalizationRules: Seq[Rule[LogicalPlan]],
+    val artifactManagerBuilder: () => ArtifactManager) {

   // The following fields are lazy to avoid creating the Hive client when creating SessionState.
   lazy val catalog: SessionCatalog = catalogBuilder()
@@ -99,6 +101,8 @@ private[sql] class SessionState(
   // when connecting to ThriftServer.
   lazy val streamingQueryManager: StreamingQueryManager = streamingQueryManagerBuilder()

+  lazy val artifactManager: ArtifactManager = artifactManagerBuilder()
+
   def catalogManager: CatalogManager = analyzer.catalogManager

   def newHadoopConf(): Configuration = SessionState.newHadoopConf(
diff --git a/sql/core/src/test/resources/artifact-tests/Hello.class b/sql/core/src/test/resources/artifact-tests/Hello.class
new file mode 100644
index 000000000000..56725764de20
Binary files /dev/null and b/sql/core/src/test/resources/artifact-tests/Hello.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/smallClassFile.class b/sql/core/src/test/resources/artifact-tests/smallClassFile.class
new file mode 100755
index 000000000000..e796030e471b
Binary files /dev/null and b/sql/core/src/test/resources/artifact-tests/smallClassFile.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/udf_noA.jar b/sql/core/src/test/resources/artifact-tests/udf_noA.jar
new file mode 100644
index 000000000000..4d8c423ab6df
Binary files /dev/null and b/sql/core/src/test/resources/artifact-tests/udf_noA.jar differ
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
similarity index 68%
rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index 0c095384de86..263006100bea 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -14,37 +14,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact

+import java.io.File
 import java.nio.charset.StandardCharsets
 import java.nio.file.{Files, Paths}
-import java.util.UUID

-import io.grpc.StatusRuntimeException
 import org.apache.commons.io.FileUtils

-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
-import org.apache.spark.sql.connect.ResourceHelper
-import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.storage.CacheId
 import org.apache.spark.util.Utils

-class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
+class ArtifactManagerSuite extends SharedSparkSession {

   override protected def sparkConf: SparkConf = {
     val conf = super.sparkConf
-    conf
-      .set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
-      .set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+    conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
   }

-  private val artifactPath = commonResourcePath.resolve("artifact-tests")
-  private lazy val sessionHolder: SessionHolder = {
-    SessionHolder("test", UUID.randomUUID().toString, spark)
-  }
-  private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder)
+  private val artifactPath = new File("src/test/resources/artifact-tests").toPath
+
+  private lazy val artifactManager = spark.artifactManager

   private def sessionUUID: String = spark.sessionUUID

@@ -61,7 +55,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
       assert(stagingPath.toFile.exists())
       artifactManager.addArtifact(remotePath, stagingPath, None)

-      val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+      val movedClassFile = ArtifactManager.artifactRootDirectory
         .resolve(s"$sessionUUID/classes/smallClassFile.class")
         .toFile
       assert(movedClassFile.exists())
@@ -75,7 +69,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
       assert(stagingPath.toFile.exists())
       artifactManager.addArtifact(remotePath, stagingPath, None)

-      val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+      val movedClassFile = ArtifactManager.artifactRootDirectory
         .resolve(s"$sessionUUID/classes/Hello.class")
         .toFile
       assert(movedClassFile.exists())
@@ -98,16 +92,14 @@ class ArtifactManagerSuite extends SharedSparkSession {
       val remotePath = Paths.get("classes/Hello.class")
       assert(stagingPath.toFile.exists())
-      val sessionHolder =
-        SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString())
-      sessionHolder.addArtifact(remotePath, stagingPath, None)
+      artifactManager.addArtifact(remotePath, stagingPath, None)

-      val movedClassFile = SparkConnectArtifactManager.artifactRootPath
-        .resolve(s"${sessionHolder.session.sessionUUID}/classes/Hello.class")
+      val movedClassFile = ArtifactManager.artifactRootDirectory
+        .resolve(s"${spark.sessionUUID}/classes/Hello.class")
         .toFile
       assert(movedClassFile.exists())

-      val classLoader = sessionHolder.classloader
+      val classLoader = spark.artifactManager.classloader
       val instance = classLoader
         .loadClass("Hello")
         .getDeclaredConstructor(classOf[String])
@@ -115,8 +107,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
         .asInstanceOf[String => String]
       val udf = org.apache.spark.sql.functions.udf(instance)
-      sessionHolder.withSession { session =>
-        session.range(10).select(udf(col("id").cast("string"))).collect()
+      spark.artifactManager.withResources {
+        spark.range(10).select(udf(col("id").cast("string"))).collect()
       }
   }

@@ -125,9 +117,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
       val stagingPath = path.toPath
       Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
       val remotePath = Paths.get("cache/abc")
-      val session = sessionHolder
       val blockManager = spark.sparkContext.env.blockManager
-      val blockId = CacheId(session.userId, session.sessionId, "abc")
+      val blockId = CacheId(spark.sessionUUID, "abc")
       try {
         artifactManager.addArtifact(remotePath, stagingPath, None)
         val bytes = blockManager.getLocalBytes(blockId)
@@ -136,7 +127,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
         assert(readback === "test")
       } finally {
         blockManager.releaseLock(blockId)
-        blockManager.removeCache(session.userId, session.sessionId)
+        blockManager.removeCache(spark.sessionUUID)
       }
     }
   }
@@ -147,7 +138,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
       Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
       val remotePath = Paths.get("pyfiles/abc.zip")
       artifactManager.addArtifact(remotePath, stagingPath, None)
-      assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
+      assert(artifactManager.getPythonIncludes == Seq("abc.zip"))
     }
   }

@@ -167,7 +158,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
     withTempPath { path =>
       Files.write(path.toPath, "updated file".getBytes(StandardCharsets.UTF_8))

-      assertThrows[StatusRuntimeException] {
+      assertThrows[RuntimeException] {
         artifactManager.addArtifact(remotePath, path.toPath, None)
       }
     }
@@ -193,9 +184,8 @@ class ArtifactManagerSuite extends SharedSparkSession {
       val stagingPath = path.toPath
       Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
       val remotePath = Paths.get("cache/abc")
-      val session = sessionHolder
       val blockManager = spark.sparkContext.env.blockManager
-      val blockId = CacheId(session.userId, session.sessionId, "abc")
+      val blockId = CacheId(spark.sessionUUID, "abc")
       // Setup artifact dir
       val copyDir = Utils.createTempDir().toPath
       FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
@@ -209,7 +199,7 @@ class ArtifactManagerSuite extends SharedSparkSession {
         val bytes = blockManager.getLocalBytes(blockId)
         assert(bytes.isDefined)
         blockManager.releaseLock(blockId)
-        val expectedPath = SparkConnectArtifactManager.artifactRootPath
+        val expectedPath = ArtifactManager.artifactRootDirectory
           .resolve(s"$sessionUUID/classes/smallClassFile.class")
         assert(expectedPath.toFile.exists())

@@ -226,32 +216,30 @@ class ArtifactManagerSuite extends SharedSparkSession {
         case throwable: Throwable => throw throwable
       } finally {
         FileUtils.deleteDirectory(copyDir.toFile)
-        blockManager.removeCache(session.userId, session.sessionId)
+        blockManager.removeCache(spark.sessionUUID)
       }
     }
   }

   test("Classloaders for spark sessions are isolated") {
-    // use same sessionId - different users should still make it isolated.
-    val sessionId = UUID.randomUUID.toString()
-    val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId)
-    val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", sessionId)
-    val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", sessionId)
+    val session1 = spark.newSession()
+    val session2 = spark.newSession()
+    val session3 = spark.newSession()

-    def addHelloClass(holder: SessionHolder): Unit = {
+    def addHelloClass(session: SparkSession): Unit = {
       val copyDir = Utils.createTempDir().toPath
       FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
       val stagingPath = copyDir.resolve("Hello.class")
       val remotePath = Paths.get("classes/Hello.class")
       assert(stagingPath.toFile.exists())
-      holder.addArtifact(remotePath, stagingPath, None)
+      session.artifactManager.addArtifact(remotePath, stagingPath, None)
     }

     // Add the "Hello" classfile for the first user
-    addHelloClass(holder1)
+    addHelloClass(session1)

-    val classLoader1 = holder1.classloader
+    val classLoader1 = session1.artifactManager.classloader
     val instance1 = classLoader1
       .loadClass("Hello")
       .getDeclaredConstructor(classOf[String])
@@ -259,13 +247,13 @@ class ArtifactManagerSuite extends SharedSparkSession {
       .asInstanceOf[String => String]
     val udf1 = org.apache.spark.sql.functions.udf(instance1)

-    holder1.withSession { session =>
-      val result = session.range(10).select(udf1(col("id").cast("string"))).collect()
-      assert(result.forall(_.getString(0).contains("Talon")))
+    session1.artifactManager.withResources {
+      val result1 = session1.range(10).select(udf1(col("id").cast("string"))).collect()
+      assert(result1.forall(_.getString(0).contains("Talon")))
     }

     assertThrows[ClassNotFoundException] {
-      val classLoader2 = holder2.classloader
+      val classLoader2 = session2.artifactManager.classloader
       val instance2 = classLoader2
         .loadClass("Hello")
         .getDeclaredConstructor(classOf[String])
@@ -274,17 +262,19 @@ class ArtifactManagerSuite extends SharedSparkSession {
     }

     // Add the "Hello" classfile for the third user
-    addHelloClass(holder3)
-    val instance3 = holder3.classloader
+    addHelloClass(session3)
+
+    val classLoader3 = session3.artifactManager.classloader
+    val instance3 = classLoader3
       .loadClass("Hello")
       .getDeclaredConstructor(classOf[String])
       .newInstance("Ahri")
       .asInstanceOf[String => String]
     val udf3 = org.apache.spark.sql.functions.udf(instance3)

-    holder3.withSession { session =>
-      val result = session.range(10).select(udf3(col("id").cast("string"))).collect()
-      assert(result.forall(_.getString(0).contains("Ahri")))
+    session3.artifactManager.withResources {
+      val result3 = session3.range(10).select(udf3(col("id").cast("string"))).collect()
+      assert(result3.forall(_.getString(0).contains("Ahri")))
     }
   }

@@ -294,36 +284,13 @@ class ArtifactManagerSuite extends SharedSparkSession {
     val copyDir = Utils.createTempDir().toPath
     FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
     val stagingPath = copyDir.resolve("Hello.class")
     val remotePath = Paths.get("classes/Hello.class")

-    val holder =
-      SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString)
-    holder.addArtifact(remotePath, stagingPath, None)
+    artifactManager.addArtifact(remotePath, stagingPath, None)

-    val sessionDirectory =
-      SparkConnectArtifactManager.getArtifactDirectoryAndUriForSession(holder)._1.toFile
+    val sessionDirectory = artifactManager.artifactPath.toFile
     assert(sessionDirectory.exists())

-    holder.artifactManager.cleanUpResources()
+    artifactManager.cleanUpResources()
     assert(!sessionDirectory.exists())
-    assert(SparkConnectArtifactManager.artifactRootPath.toFile.exists())
-  }
-}
-
-class ArtifactUriSuite extends SparkFunSuite with LocalSparkContext {
-
-  private def createSparkContext(): Unit = {
-    resetSparkContext()
-    sc = new SparkContext("local[4]", "test", new SparkConf())
-  }
-
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    createSparkContext()
-  }
-
-  test("Artifact URI is reset when SparkContext is restarted") {
-    val oldUri = SparkConnectArtifactManager.artifactRootURI
-    createSparkContext()
-    val newUri = SparkConnectArtifactManager.artifactRootURI
-    assert(newUri != oldUri)
+    assert(ArtifactManager.artifactRootDirectory.toFile.exists())
   }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
similarity index 94%
rename from connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
index bde9a71fa17e..c1a0cc274009 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact

 import java.io.File

@@ -23,8 +23,11 @@ import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader}

 class StubClassLoaderSuite extends SparkFunSuite {

-  // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
-  private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL
+  // TODO: Modify JAR to remove references to connect.
+  // See connector/connect/client/jvm/src/test/resources/StubClassDummyUdf for how the UDFs and
+  // jars are created.
+  private val udfNoAJar = new File(
+    "src/test/resources/artifact-tests/udf_noA.jar").toURI.toURL
   private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf"
   private val classA = "org.apache.spark.sql.connect.client.A"