Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
package org.apache.spark.sql.application

import java.io.{PipedInputStream, PipedOutputStream}
import java.nio.file.Paths
import java.util.concurrent.{Executors, Semaphore, TimeUnit}

import scala.util.Properties

import org.apache.commons.io.output.ByteArrayOutputStream
import org.scalatest.BeforeAndAfterEach

import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}

class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {

Expand All @@ -35,6 +38,11 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
private var ammoniteIn: PipedInputStream = _
private val semaphore: Semaphore = new Semaphore(0)

private val scalaVersion = Properties.versionNumberString
.split("\\.")
.take(2)
.mkString(".")

private def getCleanString(out: ByteArrayOutputStream): String = {
// Remove ANSI colour codes
// Regex taken from https://stackoverflow.com/a/25189932
Expand Down Expand Up @@ -96,7 +104,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {

def assertContains(message: String, output: String): Unit = {
val isContain = output.contains(message)
assert(isContain, "Ammonite output did not contain '" + message + "':\n" + output)
assert(
isContain,
"Ammonite output did not contain '" + message + "':\n" + output +
s"\nError Output: ${getCleanString(errorStream)}")
}

test("Simple query") {
Expand Down Expand Up @@ -151,4 +162,33 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
assertContains("Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)", output)
}

test("Client-side JAR") {
// scalastyle:off classforname line.size.limit
val sparkHome = IntegrationTestUtils.sparkHome
val testJar = Paths
.get(
s"$sparkHome/connector/connect/client/jvm/src/test/resources/TestHelloV2_$scalaVersion.jar")
.toFile

assert(testJar.exists(), "Missing TestHelloV2 jar!")
val input = s"""
|import java.nio.file.Paths
|def classLoadingTest(x: Int): Int = {
| val classloader =
| Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
| val cls = Class.forName("com.example.Hello$$", true, classloader)
| val module = cls.getField("MODULE$$").get(null)
| cls.getMethod("test").invoke(module).asInstanceOf[Int]
|}
|val classLoaderUdf = udf(classLoadingTest _)
|
|val jarPath = Paths.get("${testJar.toString}").toUri
|spark.addArtifact(jarPath)
|
|spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
""".stripMargin
val output = runCommandsInShell(input)
assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output)
// scalastyle:on classforname line.size.limit
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
jarsList.add(target)
jarsURI.add(artifactURI + "/" + target.toString)
jarsURI.add(artifactURI + "/" + remoteRelativePath.toString)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
sessionHolder.session.sparkContext.addFile(target.toString)
val stringRemotePath = remoteRelativePath.toString
Expand Down