diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5a0f33ffd5dc..cc6c8100d3c6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -554,7 +554,7 @@ class SparkSession private[sql] ( val command = proto.Command.newBuilder().setRegisterFunction(udf).build() val plan = proto.Plan.newBuilder().setCommand(command).build() - client.execute(plan) + client.execute(plan).asScala.foreach(_ => ()) } @DeveloperApi diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 18aef8a2e4cf..e5c89d90c19b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -92,7 +92,7 @@ sealed abstract class UserDefinedFunction { /** * Holder class for a scalar user-defined function and its input/output encoder(s). */ -case class ScalarUserDefinedFunction private ( +case class ScalarUserDefinedFunction private[sql] ( // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. serializedUdfPacket: Array[Byte], inputTypes: Seq[proto.DataType], diff --git a/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala new file mode 100644 index 000000000000..ff1b3deafafd --- /dev/null +++ b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.connect.client + +// To generate a jar from the source file: +// `scalac StubClassDummyUdf.scala -d udf.jar` +// To remove class A from the jar: +// `jar -xvf udf.jar` -> delete A.class and A$.class +// `jar -cvf udf_noA.jar org/` +class StubClassDummyUdf { + val udf: Int => Int = (x: Int) => x + 1 + val dummy = (x: Int) => A(x) +} + +case class A(x: Int) { def get: Int = x + 5 } + +// The code below was used to generate the serialized udf file +object StubClassDummyUdf { + import java.io.{BufferedOutputStream, File, FileOutputStream} + import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder + import org.apache.spark.sql.connect.common.UdfPacket + import org.apache.spark.util.Utils + + def packDummyUdf(): String = { + val byteArray = + Utils.serialize[UdfPacket]( + new UdfPacket( + new StubClassDummyUdf().udf, + Seq(PrimitiveIntEncoder), + PrimitiveIntEncoder + ) + ) + val file = new File("src/test/resources/udf") + val target = new BufferedOutputStream(new FileOutputStream(file)) + try { + target.write(byteArray) + file.getAbsolutePath + } finally { + target.close() + } + } +} diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12 b/connector/connect/client/jvm/src/test/resources/udf2.12 new file mode 100644 index 000000000000..1090bc90d9b4 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12.jar b/connector/connect/client/jvm/src/test/resources/udf2.12.jar new file mode 100644 index 000000000000..6ce6799678f9 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12.jar differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13 b/connector/connect/client/jvm/src/test/resources/udf2.13 new file mode 100644 index 000000000000..863ac32a76dc Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13 differ diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13.jar b/connector/connect/client/jvm/src/test/resources/udf2.13.jar new file mode 100644 index 000000000000..c89830f127c0 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13.jar differ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala new file mode 100644 index 000000000000..8fdb7efbcba7 --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql.connect.client + +import java.io.File +import java.nio.file.{Files, Paths} + +import scala.util.Properties + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.connect.common.ProtoDataTypes +import org.apache.spark.sql.expressions.ScalarUserDefinedFunction + +class UDFClassLoadingE2ESuite extends RemoteSparkSession { + + private val scalaVersion = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfByteArray: Array[Byte] = + Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion")) + private val udfJar = + new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL + + private def registerUdf(session: SparkSession): Unit = { + val udf = ScalarUserDefinedFunction( + serializedUdfPacket = udfByteArray, + inputTypes = Seq(ProtoDataTypes.IntegerType), + outputType = ProtoDataTypes.IntegerType, + name = Some("dummyUdf"), + nullable = true, + deterministic = true) + session.registerUdf(udf.toProto) + } + + test("update class loader after stubbing: new session") { + // Session1 should stub the missing class, but fail to call methods on it + val session1 = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session1) + }.getMessage.contains( + "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session2 uses the real class + val session2 = spark.newSession() + session2.addArtifact(udfJar.toURI) + registerUdf(session2) + } + + test("update class loader after stubbing: same session") { + // Session should stub the missing class, but fail to call methods on it + val session = spark.newSession() + + assert( + intercept[Exception] { + registerUdf(session) + }.getMessage.contains( + "java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf")) + + // Session uses the real class + session.addArtifact(udfJar.toURI) + registerUdf(session) + } +} diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index 0eaca7577b92..3b88722f8c34 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -29,7 +29,7 @@ object IntegrationTestUtils { // System properties used for testing and debugging private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client" - // Enable this flag to print all client debug log + server logs to the console + // Enable this flag to print all server logs to the console private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean private[sql] lazy val scalaVersion = { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index 594d3c369fe6..1c1cb1403fee 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -96,7 +96,7 @@ object SparkConnectServerUtils { // To find InMemoryTableCatalog for V2 writer tests val catalystTestJar = tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true) - .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath)) + .map(clientTestJar => Seq(clientTestJar.getCanonicalPath)) .getOrElse(Seq.empty) // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests. diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index d8f290639c2f..03391cef68b0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -31,12 +31,13 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath} import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connect.artifact.util.ArtifactUtils import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.storage.{CacheId, StorageLevel} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils} /** * The Artifact Manager for the [[SparkConnectService]]. 
@@ -161,7 +162,19 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging */ def classloader: ClassLoader = { val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL - new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) { + val stubClassLoader = + StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES)) + new ChildFirstURLClassLoader( + urls.toArray, + stubClassLoader, + Utils.getContextOrSparkClassLoader) + } else { + new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader) + } + + logDebug(s"Using class loader: $loader, containing urls: $urls") + loader } /** diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index e4ac34715fb2..ebed8af48f08 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.planner +import java.io.IOException + import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Try @@ -1504,15 +1506,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { } private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = { - Utils.deserialize[UdfPacket]( - fun.getScalarScalaUdf.getPayload.toByteArray, - Utils.getContextOrSparkClassLoader) + unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf) } private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = { - Utils.deserialize[ForeachWriterPacket]( - fun.getPayload.toByteArray, - Utils.getContextOrSparkClassLoader) + unpackScalarScalaUDF[ForeachWriterPacket](fun) + } + + private def unpackScalarScalaUDF[T](fun: proto.ScalarScalaUDF): T = { + try { + logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}") + Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader) + } catch { + case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] => + throw new ClassNotFoundException( + s"Failed to load class correctly due to ${e.getCause}. 
" + + "Make sure the artifact where the class is defined is installed by calling" + + " session.addArtifact.") + } } /** diff --git a/connector/connect/server/src/test/resources/udf b/connector/connect/server/src/test/resources/udf new file mode 100644 index 000000000000..55a3264a017f Binary files /dev/null and b/connector/connect/server/src/test/resources/udf differ diff --git a/connector/connect/server/src/test/resources/udf_noA.jar b/connector/connect/server/src/test/resources/udf_noA.jar new file mode 100644 index 000000000000..4d8c423ab6df Binary files /dev/null and b/connector/connect/server/src/test/resources/udf_noA.jar differ diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala new file mode 100644 index 000000000000..0f6e05431518 --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.artifact + +import java.io.File + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader} + +class StubClassLoaderSuite extends SparkFunSuite { + + // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created. + private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL + private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf" + private val classA = "org.apache.spark.sql.connect.client.A" + + test("find class with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("class for name with stub class") { + val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true) + // scalastyle:off classforname + val cls = Class.forName("my.name.HelloWorld", false, cl) + // scalastyle:on classforname + assert(cls.getName === "my.name.HelloWorld") + assert(cl.lastStubbed === "my.name.HelloWorld") + } + + test("filter class to stub") { + val list = "my.name" :: Nil + val cl = StubClassLoader(getClass().getClassLoader(), list) + val cls = cl.findClass("my.name.HelloWorld") + assert(cls.getName === "my.name.HelloWorld") + + intercept[ClassNotFoundException] { + cl.findClass("name.my.GoodDay") + } + } + + test("stub missing class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + // Install artifact without class A. 
+ val sessionClassLoader = + new ChildFirstURLClassLoader(Array(udfNoAJar), stubClassLoader, sysClassLoader) + // Load udf with A used in the same class. + loadDummyUdf(sessionClassLoader) + // Class A should be stubbed. + assert(stubClassLoader.lastStubbed === classA) + } + + test("unload stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val cl1 = new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf + intercept[Exception] { + loadDummyUdf(cl1) + } + // Successfully stubbed the missing class. + assert(stubClassLoader.lastStubbed === classDummyUdf) + + // Creating a new class loader will unpack the udf correctly. + val cl2 = new ChildFirstURLClassLoader( + Array(udfNoAJar), + stubClassLoader, // even with the same stub class loader. + sysClassLoader) + // Should be able to load after the artifact is added + loadDummyUdf(cl2) + } + + test("throw no such method if trying to access methods on stub class") { + val sysClassLoader = getClass.getClassLoader() + val stubClassLoader = new RecordedStubClassLoader(null, _ => true) + + val sessionClassLoader = + new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader) + + // Failed to load DummyUdf because of missing methods + assert(intercept[NoSuchMethodException] { + loadDummyUdf(sessionClassLoader) + }.getMessage.contains(classDummyUdf)) + // Successfully stubbed the missing class. + assert(stubClassLoader.lastStubbed === classDummyUdf) + } + + private def loadDummyUdf(sessionClassLoader: ClassLoader): Unit = { + // Load DummyUdf and call a method on it. + // scalastyle:off classforname + val cls = Class.forName(classDummyUdf, false, sessionClassLoader) + // scalastyle:on classforname + cls.getDeclaredMethod("dummy") + + // Load class A used inside DummyUdf + // scalastyle:off classforname + Class.forName(classA, false, sessionClassLoader) + // scalastyle:on classforname + } +} + +class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) + extends StubClassLoader(parent, shouldStub) { + var lastStubbed: String = _ + + override def findClass(name: String): Class[_] = { + if (shouldStub(name)) { + lastStubbed = name + } + super.findClass(name) + } +} diff --git a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java index 57d96756c8be..2791209e019b 100644 --- a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java +++ b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java @@ -40,6 +40,15 @@ public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent) { this.parent = new ParentClassLoader(parent); } + /** + * Specify the grandparent if there is a need to load in the order of + * `grandparent -> urls (child) -> parent`. 
+ */ + public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent, ClassLoader grandparent) { + super(urls, grandparent); + this.parent = new ParentClassLoader(parent); + } + @Override public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { try { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index b30569dc9641..9327ea4d3dd7 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -56,11 +56,12 @@ import org.apache.spark.util._ private[spark] class IsolatedSessionState( val sessionUUID: String, - val urlClassLoader: MutableURLClassLoader, + var urlClassLoader: MutableURLClassLoader, var replClassLoader: ClassLoader, val currentFiles: HashMap[String, Long], val currentJars: HashMap[String, Long], - val currentArchives: HashMap[String, Long]) + val currentArchives: HashMap[String, Long], + val replClassDirUri: Option[String]) /** * Spark executor, backed by a threadpool to run tasks. @@ -173,14 +174,20 @@ private[spark] class Executor( val currentFiles = new HashMap[String, Long] val currentJars = new HashMap[String, Long] val currentArchives = new HashMap[String, Long] - val urlClassLoader = createClassLoader(currentJars) + val urlClassLoader = createClassLoader(currentJars, !isDefaultState(jobArtifactState.uuid)) val replClassLoader = addReplClassLoaderIfNeeded( - urlClassLoader, jobArtifactState.replClassDirUri) + urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid) new IsolatedSessionState( jobArtifactState.uuid, urlClassLoader, replClassLoader, - currentFiles, currentJars, currentArchives) + currentFiles, + currentJars, + currentArchives, + jobArtifactState.replClassDirUri + ) } + private def isDefaultState(name: String) = name == "default" + // Classloader isolation // The default isolation group val defaultSessionState = newSessionState(JobArtifactState("default", None)) @@ -514,9 +521,8 @@ private[spark] class Executor( // Classloader isolation val isolatedSession = taskDescription.artifacts.state match { - case Some(jobArtifactState) => isolatedSessionCache.get( - jobArtifactState.uuid, - () => newSessionState(jobArtifactState)) + case Some(jobArtifactState) => + isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState)) case _ => defaultSessionState } @@ -548,6 +554,9 @@ private[spark] class Executor( taskDescription.artifacts.jars, taskDescription.artifacts.archives, isolatedSession) + // Always reset the thread context class loader so that, if any dependencies were updated, + // all threads (not only the thread that updated the dependencies) pick up the new loader. + Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader) task = ser.deserialize[Task[Any]]( taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) task.localProperties = taskDescription.properties @@ -999,7 +1008,9 @@ private[spark] class Executor( * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(currentJars: HashMap[String, Long]): MutableURLClassLoader = { + private def createClassLoader( + currentJars: HashMap[String, Long], + useStub: Boolean): MutableURLClassLoader = { // Bootstrap the list of jars with the user class path.
val now = System.currentTimeMillis() userClassPath.foreach { url => @@ -1011,8 +1022,23 @@ val urls = userClassPath.toArray ++ currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL } - logInfo(s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " + - urls.mkString("'", ",", "'")) + createClassLoader(urls, useStub) + } + + private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = { + logInfo( + s"Starting executor with user classpath (userClassPathFirst = $userClassPathFirst): " + + urls.mkString("'", ",", "'") + ) + + if (useStub && conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) { + createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_CLASSES)) + } else { + createClassLoader(urls) + } + } + + private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = { if (userClassPathFirst) { new ChildFirstURLClassLoader(urls, systemLoader) } else { @@ -1020,20 +1046,39 @@ } + private def createClassLoaderWithStub( + urls: Array[URL], + binaryName: Seq[String]): MutableURLClassLoader = { + if (userClassPathFirst) { + // user -> (sys -> stub) + val stubClassLoader = + StubClassLoader(systemLoader, binaryName) + new ChildFirstURLClassLoader(urls, stubClassLoader) + } else { + // sys -> user -> stub + val stubClassLoader = + StubClassLoader(null, binaryName) + new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader) + } + } + /** * If the REPL is in use, add another ClassLoader that will read * new classes defined by the REPL as the user types code */ private def addReplClassLoaderIfNeeded( parent: ClassLoader, - sessionClassUri: Option[String]): ClassLoader = { + sessionClassUri: Option[String], + sessionUUID: String): ClassLoader = { val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null)) - if (classUri != null) { + val classLoader = if (classUri != null) { logInfo("Using REPL class URI: " + classUri) new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst) } else { parent } + logInfo(s"Created or updated repl class loader $classLoader for $sessionUUID.") + classLoader } /** @@ -1048,6 +1093,7 @@ private[spark] class Executor( state: IsolatedSessionState, testStartLatch: Option[CountDownLatch] = None, testEndLatch: Option[CountDownLatch] = None): Unit = { + var updated = false; lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) updateDependenciesLock.lockInterruptibly() try { @@ -1056,7 +1102,7 @@ // If the session ID was specified from SparkSession, it's from a Spark Connect client. // Specify a dedicated directory for Spark Connect client. - lazy val root = if (state.sessionUUID != "default") { + lazy val root = if (!isDefaultState(state.sessionUUID)) { val newDest = new File(SparkFiles.getRootDirectory(), state.sessionUUID) newDest.mkdir() newDest @@ -1101,11 +1147,21 @@ // Add it to our class loader val url = new File(root, localName).toURI.toURL if (!state.urlClassLoader.getURLs().contains(url)) { - logInfo(s"Adding $url to class loader") + logInfo(s"Adding $url to the class loader for session ${state.sessionUUID}") state.urlClassLoader.addURL(url) + if (!isDefaultState(state.sessionUUID)) { + updated = true + } } } } + if (updated) { + // When a new url is added for a non-default class loader, recreate the class loader + // to ensure all classes are updated.
+ state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true) + state.replClassLoader = + addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID) + } // For testing, so we can simulate a slow file download: testEndLatch.foreach(_.await()) } finally { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 04eba8bddeb6..06c512cee9da 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2547,4 +2547,18 @@ package object config { .version("3.5.0") .booleanConf .createWithDefault(false) + + private[spark] val CONNECT_SCALA_UDF_STUB_CLASSES = + ConfigBuilder("spark.connect.scalaUdf.stubClasses") + .internal() + .doc(""" |Comma-separated list of binary names of classes/packages that should be stubbed during |the Scala UDF serde and execution if not found on the server classpath. |An empty list effectively disables stubbing for all missing classes. |By default, the server stubs classes from the Scala client package. |""".stripMargin) + .version("3.5.0") + .stringConf + .toSequence + .createWithDefault("org.apache.spark.sql.connect.client" :: Nil) } diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala new file mode 100644 index 000000000000..a0bc753f4887 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.xbean.asm9.{ClassWriter, Opcodes} + +/** + * [[ClassLoader]] that replaces missing classes with stubs if they cannot be found. It will only + * do this for classes that are marked for stubbing. + * + * While this is generally not a good idea, in this particular case it is used to load lambdas + * whose capturing class contains unknown (and unneeded) classes. The lambda itself does not need + * the class and is therefore safe to replace with a stub.
+ */ +class StubClassLoader(parent: ClassLoader, shouldStub: String => Boolean) + extends ClassLoader(parent) { + override def findClass(name: String): Class[_] = { + if (!shouldStub(name)) { + throw new ClassNotFoundException(name) + } + val bytes = StubClassLoader.generateStub(name) + defineClass(name, bytes, 0, bytes.length) + } +} + +object StubClassLoader { + def apply(parent: ClassLoader, binaryName: Seq[String]): StubClassLoader = { + new StubClassLoader(parent, name => binaryName.exists(p => name.startsWith(p))) + } + + def generateStub(binaryName: String): Array[Byte] = { + // Convert the binary name to an internal name. + val name = binaryName.replace('.', '/') + val classWriter = new ClassWriter(0) + classWriter.visit( + 49, + Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER, + name, + null, + "java/lang/Object", + null) + classWriter.visitSource(name + ".java", null) + + // Generate a public no-arg constructor that delegates to java/lang/Object. + val ctorWriter = classWriter.visitMethod( + Opcodes.ACC_PUBLIC, + "<init>", + "()V", + null, + null) + ctorWriter.visitVarInsn(Opcodes.ALOAD, 0) + ctorWriter.visitMethodInsn( + Opcodes.INVOKESPECIAL, + "java/lang/Object", + "<init>", + "()V", + false) + + ctorWriter.visitInsn(Opcodes.RETURN) + ctorWriter.visitMaxs(1, 1) + ctorWriter.visitEnd() + classWriter.visitEnd() + classWriter.toByteArray + } +}
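A minimal sketch of how the pieces above compose (not part of the patch; the object name StubLoaderSketch and the class name Missing are hypothetical, and the empty jar array stands in for session artifacts). It wires the loaders in the "sys -> user -> stub" order used by Executor.createClassLoaderWithStub, with the stub list mirroring the default of spark.connect.scalaUdf.stubClasses:

import java.net.URL

import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader}

object StubLoaderSketch {
  def main(args: Array[String]): Unit = {
    // Stub only classes under the Scala client package (the config default above).
    val stub = StubClassLoader(null, Seq("org.apache.spark.sql.connect.client"))
    // Load order: system loader -> session jars (empty here) -> stub generator.
    val loader = new ChildFirstURLClassLoader(Array.empty[URL], stub, getClass.getClassLoader)

    // A class missing from every jar still loads, as a generated stub.
    val cls = Class.forName("org.apache.spark.sql.connect.client.Missing", false, loader)
    assert(cls.getName == "org.apache.spark.sql.connect.client.Missing")

    // The stub only has a no-arg constructor, so any method lookup fails with
    // NoSuchMethodException; SparkConnectPlanner.unpackScalarScalaUDF surfaces this
    // as a ClassNotFoundException asking the user to call session.addArtifact.
    try cls.getDeclaredMethod("dummy")
    catch { case e: NoSuchMethodException => println(s"stubbed: ${e.getMessage}") }
  }
}

This is the same sequence StubClassLoaderSuite asserts in-process and UDFClassLoadingE2ESuite asserts against a real server: deserialization succeeds against stubs, and method access fails until the jar is added and the class loader is recreated.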