@@ -554,7 +554,7 @@ class SparkSession private[sql] (
val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
val plan = proto.Plan.newBuilder().setCommand(command).build()

-    client.execute(plan)
+    client.execute(plan).asScala.foreach(_ => ())
Contributor: Why is this needed?

Contributor (author): Currently the registerUDF call is async. I do not think it is correct for registerUDF to be async, so I added this code to block until success or error.
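
A minimal sketch of why draining the iterator blocks (the iterator below is an illustrative stand-in for the one returned by client.execute):

```scala
import scala.collection.JavaConverters._

// Stand-in for the java.util.Iterator of responses returned by
// client.execute(plan); each element would be a gRPC response message.
val responses: java.util.Iterator[String] =
  java.util.Arrays.asList("response-1", "response-2").iterator()

// Consuming every element blocks until the stream is exhausted or an
// error surfaces, turning the otherwise async call into a blocking one.
// The individual responses are discarded.
responses.asScala.foreach(_ => ())
```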

}

@DeveloperApi
@@ -92,7 +92,7 @@ sealed abstract class UserDefinedFunction {
/**
* Holder class for a scalar user-defined function and its input/output encoder(s).
*/
-case class ScalarUserDefinedFunction private (
+case class ScalarUserDefinedFunction private[sql] (
// SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
serializedUdfPacket: Array[Byte],
inputTypes: Seq[proto.DataType],
Contributor: Location is a bit weird, why not in src/test/scala?

Contributor (author): This source file cannot be on the classpath, otherwise sbt would include it in the server system classpath, so it lives under resources instead. We only need the jars and binaries, which are manually installed on the session classpath. The source file is kept in case anyone wonders what the dummy UDF looks like.
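
To illustrate the constraint (not part of this PR): had the file lived under a source directory, sbt would need an explicit exclusion to keep it from being compiled onto the server's classpath. A hypothetical sketch of such a setting:

```scala
// Hypothetical sbt setting, for illustration only: exclude one source file
// from test compilation. The PR avoids needing this by storing the file
// under src/test/resources instead of src/test/scala.
Test / unmanagedSources / excludeFilter :=
  HiddenFileFilter || "StubClassDummyUdf.scala"
```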

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.client

// To generate a jar from the source file:
// `scalac StubClassDummyUdf.scala -d udf.jar`
// To remove class A from the jar:
// `jar -xvf udf.jar` -> delete A.class and A$.class
// `jar -cvf udf_noA.jar org/`
class StubClassDummyUdf {
val udf: Int => Int = (x: Int) => x + 1
val dummy = (x: Int) => A(x)
}

case class A(x: Int) { def get: Int = x + 5 }

// The code to generate the udf file
object StubClassDummyUdf {
import java.io.{BufferedOutputStream, File, FileOutputStream}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder
import org.apache.spark.sql.connect.common.UdfPacket
import org.apache.spark.util.Utils

def packDummyUdf(): String = {
val byteArray =
Utils.serialize[UdfPacket](
new UdfPacket(
new StubClassDummyUdf().udf,
Seq(PrimitiveIntEncoder),
PrimitiveIntEncoder
)
)
val file = new File("src/test/resources/udf")
val target = new BufferedOutputStream(new FileOutputStream(file))
try {
target.write(byteArray)
file.getAbsolutePath
} finally {
target.close
}
}
}
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.client

import java.io.File
import java.nio.file.{Files, Paths}

import scala.util.Properties

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.connect.common.ProtoDataTypes
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction

class UDFClassLoadingE2ESuite extends RemoteSparkSession {

private val scalaVersion = Properties.versionNumberString
.split("\\.")
.take(2)
.mkString(".")

// See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
private val udfByteArray: Array[Byte] =
Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion"))
private val udfJar =
new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL

private def registerUdf(session: SparkSession): Unit = {
val udf = ScalarUserDefinedFunction(
serializedUdfPacket = udfByteArray,
inputTypes = Seq(ProtoDataTypes.IntegerType),
outputType = ProtoDataTypes.IntegerType,
name = Some("dummyUdf"),
nullable = true,
deterministic = true)
session.registerUdf(udf.toProto)
}

test("update class loader after stubbing: new session") {
// Session1 should stub the missing class, but fail to call methods on it
val session1 = spark.newSession()

assert(
intercept[Exception] {
registerUdf(session1)
}.getMessage.contains(
"java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf"))

// Session2 uses the real class
val session2 = spark.newSession()
session2.addArtifact(udfJar.toURI)
registerUdf(session2)
}

test("update class loader after stubbing: same session") {
// Session should stub the missing class, but fail to call methods on it
val session = spark.newSession()

assert(
intercept[Exception] {
registerUdf(session)
}.getMessage.contains(
"java.lang.NoSuchMethodException: org.apache.spark.sql.connect.client.StubClassDummyUdf"))

// Session uses the real class
session.addArtifact(udfJar.toURI)
registerUdf(session)
}
}
@@ -29,7 +29,7 @@ object IntegrationTestUtils {

// System properties used for testing and debugging
private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
-// Enable this flag to print all client debug log + server logs to the console
+// Enable this flag to print all server logs to the console
private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean

private[sql] lazy val scalaVersion = {
@@ -96,7 +96,7 @@ object SparkConnectServerUtils {
// To find InMemoryTableCatalog for V2 writer tests
val catalystTestJar =
tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true)
-      .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
+      .map(clientTestJar => Seq(clientTestJar.getCanonicalPath))
.getOrElse(Seq.empty)

// For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
@@ -31,12 +31,13 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}

import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.storage.{CacheId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}

/**
* The Artifact Manager for the [[SparkConnectService]].
@@ -161,7 +162,19 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
*/
def classloader: ClassLoader = {
val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
-    new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+    val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
+      val stubClassLoader =
+        StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+      new ChildFirstURLClassLoader(
Contributor: Should this follow the same rules for classpath resolution we have on the executor?

Contributor (author): Probably it should, to be consistent. Let me fix that in a follow-up.

Contributor (author): Actually, it is fine. There are three existing class loaders:

User CL: classes added using --jars
Sys CL: Spark + system libraries
Session CL: classes added using session.addArtifacts

In the executor:

  • normal: Sys -> (User + Session) -> Stub
  • reverse: (User + Session) -> Sys -> Stub

In the driver:

  • normal: (Sys + User) -> Session -> Stub
  • reverse: (User -> Sys) -> Session -> Stub

So here, what you see is () -> Session -> Stub.

+        urls.toArray,
+        stubClassLoader,
+        Utils.getContextOrSparkClassLoader)
+    } else {
+      new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+    }
+
+    logDebug(s"Using class loader: $loader, containing urls: $urls")
+    loader
}
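
A minimal sketch of the chain described in the comment above, using the same three-argument constructor shape as the diff (the stub prefix and jar list are illustrative placeholders):

```scala
import java.net.URL
import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader}

// Hypothetical inputs: jars added via session.addArtifact and a stub
// prefix that would come from CONNECT_SCALA_UDF_STUB_CLASSES.
val sessionJars: Array[URL] = Array.empty
val stub = StubClassLoader(null, Seq("com.example.udfs"))

// Session classes are looked up child-first in the session jars; classes
// under the stub prefix that cannot be found anywhere get a generated
// stub instead of an immediate ClassNotFoundException.
val sessionLoader =
  new ChildFirstURLClassLoader(sessionJars, stub, getClass.getClassLoader)
```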

/**
@@ -17,6 +17,8 @@

package org.apache.spark.sql.connect.planner

+import java.io.IOException

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Try
@@ -1504,15 +1506,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
}

private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
-    Utils.deserialize[UdfPacket](
-      fun.getScalarScalaUdf.getPayload.toByteArray,
-      Utils.getContextOrSparkClassLoader)
+    unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf)
}

private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = {
-    Utils.deserialize[ForeachWriterPacket](
-      fun.getPayload.toByteArray,
-      Utils.getContextOrSparkClassLoader)
+    unpackScalarScalaUDF[ForeachWriterPacket](fun)
}

+  private def unpackScalarScalaUDF[T](fun: proto.ScalarScalaUDF): T = {
+    try {
+      logDebug(s"Unpack using class loader: ${Utils.getContextOrSparkClassLoader}")
+      Utils.deserialize[T](fun.getPayload.toByteArray, Utils.getContextOrSparkClassLoader)
+    } catch {
+      case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] =>
+        throw new ClassNotFoundException(
+          s"Failed to load class correctly due to ${e.getCause}. " +
+            "Make sure the artifact where the class is defined is installed by calling" +
+            " session.addArtifact.")
Comment on lines +1521 to +1525
Contributor: In the description you write:

> If the user code is actually needed to execute the UDF, we will return an error message to suggest the user to add the missing classes using the addArtifact method.

but since this triggers during deserialization, wouldn't this also trigger for a class that is not actually used, just accidentally pulled in, and not captured by the CONNECT_SCALA_UDF_STUB_CLASSES config?

Contributor (author):

> wouldn't this also trigger for a class that is not actually used, just accidentally pulled in, and not captured by the CONNECT_SCALA_UDF_STUB_CLASSES config

The code you highlighted would not catch that class, because the case you describe fails with a ClassNotFoundException rather than a NoSuchMethodException.

Contributor: Ah, smart. That's why you catch NoSuchMethodException: it suggests actual use of the class, whereas a missing-but-unused class simply gets a stub generated. Now I finally understand, from your other comment with the explanation. That could be worth a rubber-ducky comment here saying that "while a ClassNotFoundException may be caused by an unused class accidentally pulled in by the serializer, a NoSuchMethodException suggests actual use of the class".

And @hvanhovell's comment about throwing from the default constructor is to cover the case where someone just calls the default constructor but doesn't use any methods? Also worth a rubber-ducky comment :-)

+    }
+  }
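
A minimal sketch of the distinction discussed in this thread, using a stub loader that stubs every class its parent cannot find (the class name is a hypothetical placeholder):

```scala
import org.apache.spark.util.StubClassLoader

// Stub everything the parent cannot resolve, standing in for a session
// class loader configured via CONNECT_SCALA_UDF_STUB_CLASSES.
val loader = new StubClassLoader(getClass.getClassLoader, _ => true)

// Loading a missing-but-unused class succeeds: a stub is generated, so
// deserializing a UDF payload that merely references it can proceed.
// scalastyle:off classforname
val cls = Class.forName("com.example.MissingUdf", false, loader)
// scalastyle:on classforname

// Actually using the class fails: the stub has no members. This is the
// NoSuchMethodException that the catch block above rewraps into a
// ClassNotFoundException with an addArtifact hint.
try cls.getDeclaredMethod("apply")
catch { case e: NoSuchMethodException => println(s"stubbed class in use: $e") }
```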

/**
Binary file added connector/connect/server/src/test/resources/udf
@@ -0,0 +1,132 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect.artifact

import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader}

class StubClassLoaderSuite extends SparkFunSuite {

// See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL
private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf"
private val classA = "org.apache.spark.sql.connect.client.A"

test("find class with stub class") {
val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
val cls = cl.findClass("my.name.HelloWorld")
assert(cls.getName === "my.name.HelloWorld")
assert(cl.lastStubbed === "my.name.HelloWorld")
}

test("class for name with stub class") {
val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
// scalastyle:off classforname
val cls = Class.forName("my.name.HelloWorld", false, cl)
// scalastyle:on classforname
assert(cls.getName === "my.name.HelloWorld")
assert(cl.lastStubbed === "my.name.HelloWorld")
}

test("filter class to stub") {
val list = "my.name" :: Nil
val cl = StubClassLoader(getClass().getClassLoader(), list)
val cls = cl.findClass("my.name.HelloWorld")
assert(cls.getName === "my.name.HelloWorld")

intercept[ClassNotFoundException] {
cl.findClass("name.my.GoodDay")
}
}

test("stub missing class") {
val sysClassLoader = getClass.getClassLoader()
val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

// Install artifact without class A.
val sessionClassLoader =
new ChildFirstURLClassLoader(Array(udfNoAJar), stubClassLoader, sysClassLoader)
// Load udf with A used in the same class.
loadDummyUdf(sessionClassLoader)
// Class A should be stubbed.
assert(stubClassLoader.lastStubbed === classA)
}

test("unload stub class") {
val sysClassLoader = getClass.getClassLoader()
val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

val cl1 = new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader)

// Failed to load DummyUdf
intercept[Exception] {
loadDummyUdf(cl1)
}
// Successfully stubbed the missing class.
assert(stubClassLoader.lastStubbed === classDummyUdf)

// Creating a new class loader will unpack the udf correctly.
val cl2 = new ChildFirstURLClassLoader(
Array(udfNoAJar),
stubClassLoader, // even with the same stub class loader.
sysClassLoader)
// Should be able to load after the artifact is added
loadDummyUdf(cl2)
}

test("throw no such method if trying to access methods on stub class") {
val sysClassLoader = getClass.getClassLoader()
val stubClassLoader = new RecordedStubClassLoader(null, _ => true)

val sessionClassLoader =
new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader)

// Failed to load DummyUdf because of missing methods
assert(intercept[NoSuchMethodException] {
loadDummyUdf(sessionClassLoader)
}.getMessage.contains(classDummyUdf))
// Successfully stubbed the missing class.
assert(stubClassLoader.lastStubbed === classDummyUdf)
}

private def loadDummyUdf(sessionClassLoader: ClassLoader): Unit = {
// Load DummyUdf and call a method on it.
// scalastyle:off classforname
val cls = Class.forName(classDummyUdf, false, sessionClassLoader)
// scalastyle:on classforname
cls.getDeclaredMethod("dummy")

// Load class A used inside DummyUdf
// scalastyle:off classforname
Class.forName(classA, false, sessionClassLoader)
// scalastyle:on classforname
}
}

class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean)
extends StubClassLoader(parent, shouldStub) {
var lastStubbed: String = _

override def findClass(name: String): Class[_] = {
if (shouldStub(name)) {
lastStubbed = name
}
super.findClass(name)
}
}