diff --git a/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar new file mode 100644 index 000000000000..d89cf6543a20 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar differ diff --git a/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar new file mode 100644 index 000000000000..6dee8fcd9c95 Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar differ diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 61959234c879..720f66680ee1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.application import java.io.{PipedInputStream, PipedOutputStream} +import java.nio.file.Paths import java.util.concurrent.{Executors, Semaphore, TimeUnit} +import scala.util.Properties + import org.apache.commons.io.output.ByteArrayOutputStream import org.scalatest.BeforeAndAfterEach -import org.apache.spark.sql.connect.client.util.RemoteSparkSession +import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession} class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { @@ -35,6 +38,11 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { private var ammoniteIn: PipedInputStream = _ private val semaphore: Semaphore = new Semaphore(0) + private val scalaVersion = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + private def getCleanString(out: ByteArrayOutputStream): String = { // Remove ANSI colour codes // Regex taken from https://stackoverflow.com/a/25189932 @@ -96,7 +104,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { def assertContains(message: String, output: String): Unit = { val isContain = output.contains(message) - assert(isContain, "Ammonite output did not contain '" + message + "':\n" + output) + assert( + isContain, + "Ammonite output did not contain '" + message + "':\n" + output + + s"\nError Output: ${getCleanString(errorStream)}") } test("Simple query") { @@ -151,4 +162,33 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { assertContains("Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)", output) } + test("Client-side JAR") { + // scalastyle:off classforname line.size.limit + val sparkHome = IntegrationTestUtils.sparkHome + val testJar = Paths + .get( + s"$sparkHome/connector/connect/client/jvm/src/test/resources/TestHelloV2_$scalaVersion.jar") + .toFile + + assert(testJar.exists(), "Missing TestHelloV2 jar!") + val input = s""" + |import java.nio.file.Paths + |def classLoadingTest(x: Int): Int = { + | val classloader = + | Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader) + | val cls = Class.forName("com.example.Hello$$", true, classloader) + | val module = cls.getField("MODULE$$").get(null) + | cls.getMethod("test").invoke(module).asInstanceOf[Int] + |} + |val classLoaderUdf = udf(classLoadingTest _) + | + |val jarPath = Paths.get("${testJar.toString}").toUri + |spark.addArtifact(jarPath) + | + |spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output) + // scalastyle:on classforname line.size.limit + } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala index 0a91c6b95502..9fd8e367e4aa 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala @@ -133,7 +133,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging Files.move(serverLocalStagingPath, target) if (remoteRelativePath.startsWith(s"jars${File.separator}")) { jarsList.add(target) - jarsURI.add(artifactURI + "/" + target.toString) + jarsURI.add(artifactURI + "/" + remoteRelativePath.toString) } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) { sessionHolder.session.sparkContext.addFile(target.toString) val stringRemotePath = remoteRelativePath.toString