diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 03391cef68b0..c1dd7820c55f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
 
 import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES
+import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_PREFIXES
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
 import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
@@ -162,9 +162,9 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
    */
   def classloader: ClassLoader = {
     val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
-    val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
+    val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty) {
       val stubClassLoader =
-        StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+        StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
       new ChildFirstURLClassLoader(
         urls.toArray,
         stubClassLoader,
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
index 0f6e05431518..bde9a71fa17e 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
@@ -55,6 +55,21 @@ class StubClassLoaderSuite extends SparkFunSuite {
     }
   }
 
+  test("call stub class default constructor") {
+    val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
+    // scalastyle:off classforname
+    val cls = Class.forName("my.name.HelloWorld", false, cl)
+    // scalastyle:on classforname
+    assert(cl.lastStubbed === "my.name.HelloWorld")
+    val error = intercept[java.lang.reflect.InvocationTargetException] {
+      cls.getDeclaredConstructor().newInstance()
+    }
+    assert(
+      error.getCause != null && error.getCause.getMessage.contains(
+        "Fail to initiate the class my.name.HelloWorld because it is stubbed"),
+      error)
+  }
+
   test("stub missing class") {
     val sysClassLoader = getClass.getClassLoader()
     val stubClassLoader = new RecordedStubClassLoader(null, _ => true)
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 9327ea4d3dd7..1b7bb8af79a9 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -174,7 +174,8 @@ private[spark] class Executor(
     val currentFiles = new HashMap[String, Long]
     val currentJars = new HashMap[String, Long]
     val currentArchives = new HashMap[String, Long]
-    val urlClassLoader = createClassLoader(currentJars, !isDefaultState(jobArtifactState.uuid))
+    val urlClassLoader =
+      createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid))
     val replClassLoader = addReplClassLoaderIfNeeded(
       urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
     new IsolatedSessionState(
@@ -186,6 +187,11 @@ private[spark] class Executor(
     )
   }
 
+  private def isStubbingEnabledForState(name: String) = {
+    !isDefaultState(name) &&
+      conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty
+  }
+
   private def isDefaultState(name: String) = name == "default"
 
   // Classloader isolation
@@ -1031,8 +1037,8 @@ private[spark] class Executor(
         urls.mkString("'", ",", "'")
     )
 
-    if (useStub && conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
-      createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+    if (useStub) {
+      createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
     } else {
       createClassLoader(urls)
     }
@@ -1093,7 +1099,7 @@ private[spark] class Executor(
       state: IsolatedSessionState,
       testStartLatch: Option[CountDownLatch] = None,
       testEndLatch: Option[CountDownLatch] = None): Unit = {
-    var updated = false;
+    var renewClassLoader = false;
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
     updateDependenciesLock.lockInterruptibly()
     try {
@@ -1149,15 +1155,14 @@ private[spark] class Executor(
           if (!state.urlClassLoader.getURLs().contains(url)) {
             logInfo(s"Adding $url to class loader ${state.sessionUUID}")
             state.urlClassLoader.addURL(url)
-            if (!isDefaultState(state.sessionUUID)) {
-              updated = true
+            if (isStubbingEnabledForState(state.sessionUUID)) {
+              renewClassLoader = true
             }
           }
         }
       }
-      if (updated) {
-        // When a new url is added for non-default class loader, recreate the class loader
-        // to ensure all classes are updated.
+      if (renewClassLoader) {
+        // Recreate the class loader to ensure all classes are updated.
         state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
         state.replClassLoader = addReplClassLoaderIfNeeded(state.urlClassLoader,
           state.replClassDirUri, state.sessionUUID)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index ba809b7a3b1b..81226706d677 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2556,8 +2556,8 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
-  private[spark] val CONNECT_SCALA_UDF_STUB_CLASSES =
-    ConfigBuilder("spark.connect.scalaUdf.stubClasses")
+  private[spark] val CONNECT_SCALA_UDF_STUB_PREFIXES =
+    ConfigBuilder("spark.connect.scalaUdf.stubPrefixes")
       .internal()
       .doc("""
         |Comma-separated list of binary names of classes/packages that should be stubbed during
diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
index a0bc753f4887..e27376e2b83d 100644
--- a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
@@ -70,8 +70,22 @@ object StubClassLoader {
       "()V",
       false)
 
-    ctorWriter.visitInsn(Opcodes.RETURN)
-    ctorWriter.visitMaxs(1, 1)
+    val internalException: String = "java/lang/ClassNotFoundException"
+    ctorWriter.visitTypeInsn(Opcodes.NEW, internalException)
+    ctorWriter.visitInsn(Opcodes.DUP)
+    ctorWriter.visitLdcInsn(
+      s"Fail to initiate the class $binaryName because it is stubbed. " +
" + + "Please install the artifact of the missing class by calling session.addArtifact.") + // Invoke throwable constructor + ctorWriter.visitMethodInsn( + Opcodes.INVOKESPECIAL, + internalException, + "", + "(Ljava/lang/String;)V", + false) + + ctorWriter.visitInsn(Opcodes.ATHROW) + ctorWriter.visitMaxs(3, 3) ctorWriter.visitEnd() classWriter.visitEnd() classWriter.toByteArray