@@ -288,7 +288,17 @@ object CheckConnectJvmClientCompatibility {

// SQLImplicits
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"))
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"),

// Artifact Manager
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.ArtifactManager$"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.util.ArtifactUtils"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.artifact.util.ArtifactUtils$"))
checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
}

@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit

import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.buildConf

object Connect {
@@ -206,27 +207,24 @@ object Connect {
.intConf
.createWithDefault(1024)

val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
.internal()
.doc("""
|Allow `spark.copyFromLocalToFs` destination to be local file system
| path on spark driver node when
|`spark.connect.copyFromLocalToFs.allowDestLocal` is true.
|This will allow user to overwrite arbitrary file on spark
|driver node we should only enable it for testing purpose.
|""".stripMargin)
.version("3.5.0")
.booleanConf
.createWithDefault(false)

val CONNECT_UI_STATEMENT_LIMIT =
buildStaticConf("spark.sql.connect.ui.retainedStatements")
.doc("The number of statements kept in the Spark Connect UI history.")
.version("3.5.0")
.intConf
.createWithDefault(200)

val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
.internal()
.doc(s"""
|(Deprecated since Spark 4.0, please set
|'${SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.
|""".stripMargin)
.version("3.5.0")
.booleanConf
.createWithDefault(false)

val CONNECT_UI_SESSION_LIMIT = buildStaticConf("spark.sql.connect.ui.retainedSessions")
.doc("The number of client sessions kept in the Spark Connect UI history.")
.version("3.5.0")
@@ -942,7 +942,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray.toImmutableArraySeq,
// Empty environment variables
envVars = Maps.newHashMap(),
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
@@ -996,7 +996,7 @@

private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = {
val blockManager = session.sparkContext.env.blockManager
val blockId = CacheId(sessionHolder.userId, sessionHolder.sessionId, rel.getHash)
val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
val bytes = blockManager.getLocalBytes(blockId)
bytes
.map { blockData =>
@@ -1014,7 +1014,7 @@
.getOrElse {
throw InvalidPlanInput(
s"Not found any cached local relation with the hash: ${blockId.hash} in " +
s"the session ${blockId.sessionId} for the user id ${blockId.userId}.")
s"the session with sessionUUID ${blockId.sessionUUID}.")
}
}

@@ -1633,7 +1633,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray.toImmutableArraySeq,
// Empty environment variables
envVars = Maps.newHashMap(),
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
@@ -27,18 +27,16 @@ import scala.jdk.CollectionConverters._
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.SystemClock
import org.apache.spark.util.Utils

// Unique key identifying session by combination of user, and session id
case class SessionKey(userId: String, sessionId: String)
@@ -166,7 +164,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
interruptedIds.toSeq
}

private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this)
private[connect] def artifactManager = session.artifactManager

/**
* Add an artifact to this SparkConnect session.
@@ -238,27 +236,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
eventManager.postClosed()
}

/**
* Execute a block of code using this session's classloader.
* @param f
* @tparam T
*/
def withContextClassLoader[T](f: => T): T = {
// Needed for deserializing and evaluating the UDF on the driver
Utils.withContextClassLoader(classloader) {
JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
f
}
}
}

/**
* Execute a block of code with this session as the active SparkConnect session.
* @param f
* @tparam T
*/
def withSession[T](f: SparkSession => T): T = {
withContextClassLoader {
artifactManager.withResources {
session.withActive {
f(session)
}
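
A minimal usage sketch (not part of this diff): since SessionHolder.artifactManager now just returns session.artifactManager, any code run through withSession picks up the session-scoped artifacts and classloader managed by the SparkSession itself.

    // Hypothetical caller inside the Connect server; `sessionHolder` is an existing SessionHolder.
    val rowCount: Long = sessionHolder.withSession { spark =>
      // Runs inside artifactManager.withResources and session.withActive, per the change above.
      spark.range(10).count()
    }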
@@ -30,8 +30,8 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.artifact.util.ArtifactUtils
import org.apache.spark.util.Utils

/**
@@ -101,8 +101,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
// We do not store artifacts that fail the CRC. The failure is reported in the artifact
// summary and it is up to the client to decide whether to retry sending the artifact.
if (artifact.getCrcStatus.contains(true)) {
if (artifact.path.startsWith(
SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) {
if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
@@ -33,7 +33,7 @@ class SparkConnectArtifactStatusesHandler(
.getOrCreateIsolatedSession(userId, sessionId)
.session
val blockManager = session.sparkContext.env.blockManager
blockManager.getStatus(CacheId(userId, sessionId, hash)).isDefined
blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
}

def handle(request: proto.ArtifactStatusesRequest): Unit = {
@@ -163,7 +163,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
SimplePythonFunction(
command = fcn(sparkPythonPath).toImmutableArraySeq,
envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = IntegratedUDFTestUtils.pythonExec,
pythonVer = IntegratedUDFTestUtils.pythonVer,
broadcastVars = Lists.newArrayList(),
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -190,8 +190,8 @@ class UnrecognizedBlockId(name: String)
extends SparkException(s"Failed to parse $name into a block ID")

@DeveloperApi
case class CacheId(userId: String, sessionId: String, hash: String) extends BlockId {
override def name: String = s"cache_${userId}_${sessionId}_$hash"
case class CacheId(sessionUUID: String, hash: String) extends BlockId {
override def name: String = s"cache_${sessionUUID}_$hash"
}

@DeveloperApi
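
A small sketch (not part of this diff) of the new key shape: cache blocks are now scoped by the owning SparkSession's UUID rather than by a (userId, sessionId) pair, and the block name reflects that. The values below are hypothetical.

    import org.apache.spark.storage.CacheId

    val id = CacheId(sessionUUID = "4f8c9c6e-0000-0000-0000-000000000000", hash = "ab12cd34")
    // name is built as s"cache_${sessionUUID}_$hash"
    assert(id.name == "cache_4f8c9c6e-0000-0000-0000-000000000000_ab12cd34")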
@@ -2057,10 +2057,10 @@ private[spark] class BlockManager(
*
* @return The number of blocks removed.
*/
def removeCache(userId: String, sessionId: String): Int = {
logDebug(s"Removing cache of user id = $userId in the session $sessionId")
def removeCache(sessionUUID: String): Int = {
logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
case cid: CacheId if cid.userId == userId && cid.sessionId == sessionId => cid
case cid: CacheId if cid.sessionUUID == sessionUUID => cid
}
blocksToRemove.foreach { blockId => removeBlock(blockId) }
blocksToRemove.size
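
A hedged sketch (not shown in this diff) of how the narrowed signature can be used at session teardown: every cached block belonging to a SparkSession is dropped with one call keyed by its UUID. It assumes code running inside the org.apache.spark namespace, since SparkContext.env is private[spark].

    // `session` is the SparkSession being torn down.
    val removed: Int = session.sparkContext.env.blockManager.removeCache(session.sessionUUID)
    // `removed` is the number of CacheId blocks that belonged to this session.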
10 changes: 10 additions & 0 deletions project/MimaExcludes.scala
@@ -90,6 +90,16 @@ object MimaExcludes {
// SPARK-43299: Convert StreamingQueryException in Scala Client
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException"),

// SPARK-45856: Move ArtifactManager from Spark Connect into SparkSession (sql/core)
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),
@LuciferYang (Contributor) commented on Dec 13, 2023:
@vicennial I would like to reconfirm: the ProblemFilters added by SPARK-45856 will never need to undergo a MiMa check in versions after Spark 4.0, is that correct? Or are these just the ProblemFilters added for the MiMa check between Spark 4.0 and Spark 3.5? I found that it has been placed in defaultExcludes.

ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.userId"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.sessionId"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy$default$3"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.this"),
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.CacheId$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),

(problem: Problem) => problem match {
case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") &&
!cls.fullName.startsWith("org.sparkproject.dmg.pmml")
@@ -38,7 +38,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -30,7 +30,7 @@ class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -30,7 +30,7 @@ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
.config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
.config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)

2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -183,7 +183,7 @@ def setUpClass(cls):
@classmethod
def conf(cls):
conf = super().conf()
conf.set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return conf

def test_basic_requests(self):
@@ -4550,6 +4550,23 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

// Deprecate "spark.connect.copyFromLocalToFs.allowDestLocal" in favor of this config. This is
// currently optional because we don't want to break existing users who are using the old config.
// If this config is set, then we override the deprecated config.
val ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
buildConf("spark.sql.artifact.copyFromLocalToFs.allowDestLocal")
.internal()
.doc("""
|Allow `spark.copyFromLocalToFs` destination to be local file system
| path on spark driver node when
|`spark.sql.artifact.copyFromLocalToFs.allowDestLocal` is true.
|This will allow user to overwrite arbitrary file on spark
|driver node we should only enable it for testing purpose.
|""".stripMargin)
.version("4.0.0")
.booleanConf
.createOptional

val LEGACY_RETAIN_FRACTION_DIGITS_FIRST =
buildConf("spark.sql.legacy.decimal.retainFractionDigitsOnTruncate")
.internal()
@@ -4617,7 +4634,9 @@
DeprecatedConfig(COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "3.2",
s"Use '${COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key}' instead."),
DeprecatedConfig(ESCAPED_STRING_LITERALS.key, "4.0",
"Use raw string literals with the `r` prefix instead. ")
"Use raw string literals with the `r` prefix instead. "),
DeprecatedConfig("spark.connect.copyFromLocalToFs.allowDestLocal", "4.0",
s"Use '${ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.")
)

Map(configs.map { cfg => cfg.key -> cfg } : _*)
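
The comment on the new entry above documents the intended precedence: when the new SQL config is set, it overrides the deprecated Connect config. The diff does not show where that resolution happens; a minimal sketch of the documented intent, with the legacy value passed in as a plain boolean, could look like this.

    import org.apache.spark.sql.internal.SQLConf

    // `legacyAllowDestLocal` stands in for the value of the deprecated
    // `spark.connect.copyFromLocalToFs.allowDestLocal` static conf.
    def allowDestLocal(conf: SQLConf, legacyAllowDestLocal: Boolean): Boolean =
      conf
        .getConf(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL) // Option[Boolean], since createOptional
        .getOrElse(legacyAllowDestLocal)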
11 changes: 11 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
@@ -243,6 +244,16 @@ class SparkSession private(
@Unstable
def streams: StreamingQueryManager = sessionState.streamingQueryManager

/**
* Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts
* (jars, classfiles, etc).
*
* @since 4.0.0
*/
@Experimental
@Unstable
private[sql] def artifactManager: ArtifactManager = sessionState.artifactManager

/**
* Start a new session with isolated SQL configurations, temporary tables, registered
* functions are isolated, but sharing the underlying `SparkContext` and cached data.
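
A sketch (not part of this diff) of how the new accessor can be used from Spark's own sql package (it is private[sql], so not user-facing); withResources is the same helper the Connect SessionHolder relies on earlier in this diff.

    // Hypothetical internal caller in org.apache.spark.sql; `spark` is a SparkSession.
    spark.artifactManager.withResources {
      // Session-scoped artifacts (jars, class files) are resolvable in here.
      spark.sql("SELECT 1").collect()
    }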