Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES
import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_PREFIXES
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
Expand Down Expand Up @@ -162,9 +162,9 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
*/
def classloader: ClassLoader = {
val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty) {
val stubClassLoader =
StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
new ChildFirstURLClassLoader(
urls.toArray,
stubClassLoader,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@ class StubClassLoaderSuite extends SparkFunSuite {
}
}

test("call stub class default constructor") {
  // The predicate handed to the loader returns true for every input.
  val stubLoader = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
  // scalastyle:off classforname
  val stubbedClass = Class.forName("my.name.HelloWorld", false, stubLoader)
  // scalastyle:on classforname
  assert(stubLoader.lastStubbed === "my.name.HelloWorld")
  // Reflection wraps anything thrown by the invoked constructor in InvocationTargetException,
  // so the message under test is on the cause.
  val thrown = intercept[java.lang.reflect.InvocationTargetException] {
    stubbedClass.getDeclaredConstructor().newInstance()
  }
  val expectedFragment = "Fail to initiate the class my.name.HelloWorld because it is stubbed"
  assert(
    thrown.getCause != null && thrown.getCause.getMessage.contains(expectedFragment),
    thrown)
}

test("stub missing class") {
val sysClassLoader = getClass.getClassLoader()
val stubClassLoader = new RecordedStubClassLoader(null, _ => true)
Expand Down
23 changes: 14 additions & 9 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@ private[spark] class Executor(
val currentFiles = new HashMap[String, Long]
val currentJars = new HashMap[String, Long]
val currentArchives = new HashMap[String, Long]
val urlClassLoader = createClassLoader(currentJars, !isDefaultState(jobArtifactState.uuid))
val urlClassLoader =
createClassLoader(currentJars, isStubbingEnabledForState(jobArtifactState.uuid))
val replClassLoader = addReplClassLoaderIfNeeded(
urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
new IsolatedSessionState(
Expand All @@ -186,6 +187,11 @@ private[spark] class Executor(
)
}

/**
 * Tells whether stubbing is enabled for the session state called `name`: the state must not
 * be the default one and the CONNECT_SCALA_UDF_STUB_PREFIXES setting must be non-empty.
 */
private def isStubbingEnabledForState(name: String): Boolean = {
  if (isDefaultState(name)) {
    // The configuration is not consulted for the default state.
    false
  } else {
    conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty
  }
}

/** Returns true when `name` is exactly the string "default". */
private def isDefaultState(name: String): Boolean = "default" == name

// Classloader isolation
Expand Down Expand Up @@ -1031,8 +1037,8 @@ private[spark] class Executor(
urls.mkString("'", ",", "'")
)

if (useStub && conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
if (useStub) {
createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
} else {
createClassLoader(urls)
}
Expand Down Expand Up @@ -1093,7 +1099,7 @@ private[spark] class Executor(
state: IsolatedSessionState,
testStartLatch: Option[CountDownLatch] = None,
testEndLatch: Option[CountDownLatch] = None): Unit = {
var updated = false;
var renewClassLoader = false;
lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
updateDependenciesLock.lockInterruptibly()
try {
Expand Down Expand Up @@ -1149,15 +1155,14 @@ private[spark] class Executor(
if (!state.urlClassLoader.getURLs().contains(url)) {
logInfo(s"Adding $url to class loader ${state.sessionUUID}")
state.urlClassLoader.addURL(url)
if (!isDefaultState(state.sessionUUID)) {
updated = true
if (isStubbingEnabledForState(state.sessionUUID)) {
renewClassLoader = true
}
}
}
}
if (updated) {
// When a new url is added for non-default class loader, recreate the class loader
// to ensure all classes are updated.
if (renewClassLoader) {
// Recreate the class loader to ensure all classes are updated.
state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
state.replClassLoader =
addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2556,8 +2556,8 @@ package object config {
.booleanConf
.createWithDefault(false)

private[spark] val CONNECT_SCALA_UDF_STUB_CLASSES =
ConfigBuilder("spark.connect.scalaUdf.stubClasses")
private[spark] val CONNECT_SCALA_UDF_STUB_PREFIXES =
ConfigBuilder("spark.connect.scalaUdf.stubPrefixes")
.internal()
.doc("""
|Comma-separated list of binary names of classes/packages that should be stubbed during
Expand Down
18 changes: 16 additions & 2 deletions core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,22 @@ object StubClassLoader {
"()V",
false)

ctorWriter.visitInsn(Opcodes.RETURN)
ctorWriter.visitMaxs(1, 1)
val internalException: String = "java/lang/ClassNotFoundException"
ctorWriter.visitTypeInsn(Opcodes.NEW, internalException)
ctorWriter.visitInsn(Opcodes.DUP)
ctorWriter.visitLdcInsn(
s"Fail to initiate the class $binaryName because it is stubbed. " +
"Please install the artifact of the missing class by calling session.addArtifact.")
// Invoke throwable constructor
ctorWriter.visitMethodInsn(
Opcodes.INVOKESPECIAL,
internalException,
"<init>",
"(Ljava/lang/String;)V",
false)

ctorWriter.visitInsn(Opcodes.ATHROW)
ctorWriter.visitMaxs(3, 3)
ctorWriter.visitEnd()
classWriter.visitEnd()
classWriter.toByteArray
Expand Down