diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index cecf14a70451..ace981c3d826 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,8 +26,6 @@ import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} import io.grpc.stub.StreamObserver import org.apache.commons.lang3.exception.ExceptionUtils -import org.json4s._ -import org.json4s.jackson.JsonMethods.parse import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} @@ -91,15 +89,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) - // SparkConnectPlanner is used per request. - private lazy val pythonIncludes = { - implicit val formats = DefaultFormats - parse(session.conf.get("spark.connect.pythonUDF.includes", "[]")) - .extract[Array[String]] - .toList - .asJava - } - // The root of the query plan is a relation and we apply the transformations to it. def transformRelation(rel: proto.Relation): LogicalPlan = { val plan = rel.getRelTypeCase match { @@ -1519,7 +1508,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { command = fun.getCommand.toByteArray, // Empty environment variables envVars = Maps.newHashMap(), - pythonIncludes = pythonIncludes, + pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, pythonExec = pythonExec, pythonVer = fun.getPythonVer, // Empty broadcast variables diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 004322097790..0152f980f15d 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -24,9 +24,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark.JobArtifactSet import org.apache.spark.connect.proto import org.apache.spark.internal.Logging @@ -107,7 +104,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio * @param f * @tparam T */ - def withContext[T](f: => T): T = { + def withContextClassLoader[T](f: => T): T = { // Needed for deserializing and evaluating the UDF on the driver Utils.withContextClassLoader(classloader) { // Needed for propagating the dependencies to the executors. @@ -117,49 +114,15 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } - /** - * Set the session-based Python paths to include in Python UDF. 
- * @param f - * @tparam T - */ - def withSessionBasedPythonPaths[T](f: => T): T = { - try { - session.conf.set( - "spark.connect.pythonUDF.includes", - compact(render(artifactManager.getSparkConnectPythonIncludes))) - f - } finally { - session.conf.unset("spark.connect.pythonUDF.includes") - } - } - /** * Execute a block of code with this session as the active SparkConnect session. * @param f * @tparam T */ def withSession[T](f: SparkSession => T): T = { - withSessionBasedPythonPaths { - withContext { - session.withActive { - f(session) - } - } - } - } - - /** - * Execute a block of code using the session from this [[SessionHolder]] as the active - * SparkConnect session. - * @param f - * @tparam T - */ - def withSessionHolder[T](f: SessionHolder => T): T = { - withSessionBasedPythonPaths { - withContext { - session.withActive { - f(this) - } + withContextClassLoader { + session.withActive { + f(session) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 5c069bfaf5d0..414a852380fd 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -38,7 +38,7 @@ private[connect] class SparkConnectAnalyzeHandler( request.getSessionId) // `withSession` ensures that session-specific artifacts (such as JARs and class files) are // available during processing (such as deserialization). - sessionHolder.withSessionHolder { sessionHolder => + sessionHolder.withSession { _ => val response = process(request, sessionHolder) responseObserver.onNext(response) responseObserver.onCompleted() diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 42ab8ca18f6e..612bf096b22b 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -224,6 +224,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { test("Classloaders for spark sessions are isolated") { val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") + val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", "session3") def addHelloClass(holder: SessionHolder): Unit = { val copyDir = Utils.createTempDir().toPath @@ -234,7 +235,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { holder.addArtifact(remotePath, stagingPath, None) } - // Add the classfile only for the first user + // Add the "Hello" classfile for the first user addHelloClass(holder1) val classLoader1 = holder1.classloader @@ -246,7 +247,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val udf1 = org.apache.spark.sql.functions.udf(instance1) holder1.withSession { session => - session.range(10).select(udf1(col("id").cast("string"))).collect() + val result = session.range(10).select(udf1(col("id").cast("string"))).collect() + assert(result.forall(_.getString(0).contains("Talon"))) } 
assertThrows[ClassNotFoundException] { @@ -257,6 +259,20 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { .newInstance("Talon") .asInstanceOf[String => String] } + + // Add the "Hello" classfile for the third user + addHelloClass(holder3) + val instance3 = holder3.classloader + .loadClass("Hello") + .getDeclaredConstructor(classOf[String]) + .newInstance("Ahri") + .asInstanceOf[String => String] + val udf3 = org.apache.spark.sql.functions.udf(instance3) + + holder3.withSession { session => + val result = session.range(10).select(udf3(col("id").cast("string"))).collect() + assert(result.forall(_.getString(0).contains("Ahri"))) + } } } diff --git a/core/src/test/resources/TestHelloV2.jar b/core/src/test/resources/TestHelloV2.jar new file mode 100644 index 000000000000..d89cf6543a20 Binary files /dev/null and b/core/src/test/resources/TestHelloV2.jar differ diff --git a/core/src/test/resources/TestHelloV3.jar b/core/src/test/resources/TestHelloV3.jar new file mode 100644 index 000000000000..b175a6c86407 Binary files /dev/null and b/core/src/test/resources/TestHelloV3.jar differ diff --git a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala new file mode 100644 index 000000000000..33c1baccd729 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.util.Utils
+
+class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
+  val jar1 = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").toString
+
+  // package com.example
+  // object Hello { def test(): Int = 2 }
+  // case class Hello(x: Int, y: Int)
+  val jar2 = Thread.currentThread().getContextClassLoader.getResource("TestHelloV2.jar").toString
+
+  // package com.example
+  // object Hello { def test(): Int = 3 }
+  // case class Hello(x: String)
+  val jar3 = Thread.currentThread().getContextClassLoader.getResource("TestHelloV3.jar").toString
+
+  test("Executor classloader isolation with JobArtifactSet") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.addJar(jar1)
+    sc.addJar(jar2)
+    sc.addJar(jar3)
+
+    // TestHelloV2's test method returns '2'
+    val artifactSetWithHelloV2 = new JobArtifactSet(
+      uuid = Some("hello2"),
+      replClassDirUri = None,
+      jars = Map(jar2 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV2) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 2) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // TestHelloV3's test method returns '3'
+    val artifactSetWithHelloV3 = new JobArtifactSet(
+      uuid = Some("hello3"),
+      replClassDirUri = None,
+      jars = Map(jar3 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV3) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 3) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // Should not be able to see the "Hello" class if it is excluded from the artifact set
+    val artifactSetWithoutHello = new JobArtifactSet(
+      uuid = Some("Jar 1"),
+      replClassDirUri = None,
+      jars = Map(jar1 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithoutHello) {
+      sc.parallelize(1 to 1).foreach { i =>
+        try {
+          Utils.classForName("com.example.Hello$")
+          throw new RuntimeException("Class loading should fail")
+        } catch {
+          case _: ClassNotFoundException =>
+        }
+      }
+    }
+  }
+}