diff --git a/connector/connect/bin/spark-connect-build b/connector/connect/bin/spark-connect-build
index ca8d4cf6e9005..63c17d8f7aa06 100755
--- a/connector/connect/bin/spark-connect-build
+++ b/connector/connect/bin/spark-connect-build
@@ -29,7 +29,5 @@ SCALA_BINARY_VER=`grep "scala.binary.version" "${SPARK_HOME}/pom.xml" | head -n1
SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VER} | head -n1 | awk -F '[<>]' '{print $3}'`
SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
-# Build the jars needed for spark submit and spark connect
-build/sbt "${SCALA_ARG}" -Phive -Pconnect package || exit 1
-# Build the jars needed for spark connect JVM client
-build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1
+# Build the jars needed for spark submit and spark connect JVM client
+build/sbt "${SCALA_ARG}" -Phive -Pconnect package "connect-client-jvm/package" || exit 1
diff --git a/connector/connect/bin/spark-connect-scala-client b/connector/connect/bin/spark-connect-scala-client
index ef394df4e0f25..ffa77f708421f 100755
--- a/connector/connect/bin/spark-connect-scala-client
+++ b/connector/connect/bin/spark-connect-scala-client
@@ -45,7 +45,7 @@ SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
SCBUILD="${SCBUILD:-1}"
if [ "$SCBUILD" -eq "1" ]; then
# Build the jars needed for spark connect JVM client
- build/sbt "${SCALA_ARG}" "sql/package;connect-client-jvm/assembly" || exit 1
+ build/sbt "${SCALA_ARG}" "connect-client-jvm/package" || exit 1
fi
if [ -z "$SCCLASSPATH" ]; then
diff --git a/connector/connect/bin/spark-connect-scala-client-classpath b/connector/connect/bin/spark-connect-scala-client-classpath
index 99a22f3d5ffeb..9d33e90bf09cb 100755
--- a/connector/connect/bin/spark-connect-scala-client-classpath
+++ b/connector/connect/bin/spark-connect-scala-client-classpath
@@ -30,6 +30,5 @@ SCALA_VER=`grep "scala.version" "${SPARK_HOME}/pom.xml" | grep ${SCALA_BINARY_VE
SCALA_ARG="-Pscala-${SCALA_BINARY_VER}"
CONNECT_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export connect-client-jvm/fullClasspath" | grep jar | tail -n1)"
-SQL_CLASSPATH="$(build/sbt "${SCALA_ARG}" -DcopyDependencies=false "export sql/fullClasspath" | grep jar | tail -n1)"
-echo "$CONNECT_CLASSPATH:$CLASSPATH"
+echo "$CONNECT_CLASSPATH"
diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 04eed736fba05..d4e9b147e0264 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -113,11 +113,6 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
- <dependency>
- <groupId>org.mockito</groupId>
- <artifactId>mockito-core</artifactId>
- <scope>test</scope>
- </dependency>
<dependency>
<groupId>com.typesafe</groupId>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala
new file mode 100644
index 0000000000000..9a5fda1189d2d
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ToStub.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+/**
+ * Class used to test stubbing. It needs to live in the main source tree, because the main
+ * sources are not synced with the connect server during tests.
+ */
+case class ToStub(value: Long)
diff --git a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
index 6e5fb72d4964b..d5fdede774f47 100644
--- a/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
+++ b/connector/connect/client/jvm/src/test/java/org/apache/spark/sql/JavaEncoderSuite.java
@@ -27,9 +27,8 @@
import static org.apache.spark.sql.Encoders.*;
import static org.apache.spark.sql.functions.*;
import static org.apache.spark.sql.RowFactory.create;
-import org.apache.spark.sql.connect.client.SparkConnectClient;
-import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils;
import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.sql.test.SparkConnectServerUtils;
import org.apache.spark.sql.types.StructType;
/**
@@ -40,14 +39,7 @@ public class JavaEncoderSuite implements Serializable {
@BeforeClass
public static void setup() {
- SparkConnectServerUtils.start();
- spark = SparkSession
- .builder()
- .client(SparkConnectClient
- .builder()
- .port(SparkConnectServerUtils.port())
- .build())
- .create();
+ spark = SparkConnectServerUtils.createSparkSession();
}
@AfterClass
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
index fa97498f7e77a..cefa63ecd353e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala
@@ -22,7 +22,7 @@ import java.io.{File, FilenameFilter}
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkException
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.test.{RemoteSparkSession, SQLHelper}
import org.apache.spark.sql.types.{DoubleType, LongType, StructType}
import org.apache.spark.storage.StorageLevel
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
index 2f4e1aa9bd093..069d8ec502f52 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
@@ -22,7 +22,7 @@ import java.util.Random
import org.scalatest.matchers.must.Matchers._
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.test.RemoteSparkSession
class ClientDataFrameStatSuite extends RemoteSparkSession {
private def toLetter(i: Int): String = (i + 97).toChar.toString
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
index ae20f771d6c1f..a521c6745a90c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDatasetSuite.scala
@@ -26,8 +26,8 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.ConnectFunSuite
// Add sample tests.
// - sample fraction: simple.sample(0.1)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 74debd2a72d2a..f10f5c78ead14 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -36,10 +36,10 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
-import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
-import org.apache.spark.sql.connect.client.util.SparkConnectServerUtils.port
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession, SQLHelper}
+import org.apache.spark.sql.test.SparkConnectServerUtils.port
import org.apache.spark.sql.types._
class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
index 0d361fe1007f7..a88d6ec116a42 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala
@@ -21,7 +21,7 @@ import java.io.ByteArrayOutputStream
import scala.collection.JavaConverters._
import org.apache.spark.sql.{functions => fn}
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.sql.types._
/**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
index ac64d4411a866..393fa19fa70b4 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql
import scala.collection.JavaConverters._
-import org.apache.spark.sql.connect.client.util.QueryTest
import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.sql.types.{StringType, StructType}
class DataFrameNaFunctionSuite extends QueryTest with SQLHelper {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
index 4a8e108357fa7..78cc26d627c7c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/FunctionTestSuite.scala
@@ -21,9 +21,9 @@ import java.util.Collections
import scala.collection.JavaConverters._
import org.apache.spark.sql.avro.{functions => avroFn}
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.functions._
import org.apache.spark.sql.protobuf.{functions => pbFn}
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.sql.types.{DataType, StructType}
/**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
index 380ca2fb72b31..3e979be73a754 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/KeyValueGroupedDatasetE2ETestSuite.scala
@@ -20,9 +20,9 @@ import java.sql.Timestamp
import java.util.Arrays
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
-import org.apache.spark.sql.connect.client.util.QueryTest
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout}
+import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.sql.types._
case class ClickEvent(id: String, timestamp: Timestamp)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 11d0696b6e130..4916ff1f59743 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -37,11 +37,10 @@ import org.apache.spark.sql.avro.{functions => avroFn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.connect.client.SparkConnectClient
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
-import org.apache.spark.sql.connect.client.util.IntegrationTestUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.protobuf.{functions => pbFn}
+import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.SparkFileUtils
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
index 6db38bfb1c304..680380c91a0c2 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala
@@ -28,7 +28,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer}
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
/**
* Test suite for SQL implicits.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
index 490bdf9cd86ec..c76dc724828e5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionE2ESuite.scala
@@ -26,7 +26,7 @@ import scala.util.{Failure, Success}
import org.scalatest.concurrent.Eventually._
import org.apache.spark.SparkException
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.test.RemoteSparkSession
import org.apache.spark.util.SparkThreadUtils.awaitResult
/**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
index 4aa8b4360eebd..90fe8f57d0713 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala
@@ -22,7 +22,7 @@ import scala.util.control.NonFatal
import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor}
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
/**
* Tests for non-dataframe related SparkSession operations.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala
new file mode 100644
index 0000000000000..b9c5888e5cb77
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/StubbingTestSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connect.client.ToStub
+import org.apache.spark.sql.test.RemoteSparkSession
+
+class StubbingTestSuite extends RemoteSparkSession {
+ private def eval[T](f: => T): T = f
+
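+ // `ToStub` only exists in the client's main sources, which are not synced to the server during
+ // tests, so the server side has to fall back to class stubbing when it handles this query.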
+ test("capture of to-be stubbed class") {
+ val session = spark
+ import session.implicits._
+ val result = spark
+ .range(0, 10, 1, 1)
+ .map(n => n + 1)
+ .as[ToStub]
+ .head()
+ eval {
+ assert(result.value == 1)
+ }
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala
similarity index 94%
rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala
rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala
index 8fdb7efbcba71..a76e046db2e3a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UDFClassLoadingE2ESuite.scala
@@ -14,17 +14,16 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql
import java.io.File
import java.nio.file.{Files, Paths}
import scala.util.Properties
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
import org.apache.spark.sql.connect.common.ProtoDataTypes
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+import org.apache.spark.sql.test.RemoteSparkSession
class UDFClassLoadingE2ESuite extends RemoteSparkSession {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index d00659ac2d8eb..0af8c78a1da85 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -26,8 +26,8 @@ import scala.collection.JavaConverters._
import org.apache.spark.api.java.function._
import org.apache.spark.sql.api.java.UDF2
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveIntEncoder, PrimitiveLongEncoder}
-import org.apache.spark.sql.connect.client.util.QueryTest
import org.apache.spark.sql.functions.{col, struct, udf}
+import org.apache.spark.sql.test.QueryTest
import org.apache.spark.sql.types.IntegerType
/**
@@ -215,33 +215,31 @@ class UserDefinedFunctionE2ETestSuite extends QueryTest {
}
test("Dataset foreachPartition") {
- val sum = new AtomicLong()
val func: Iterator[JLong] => Unit = f => {
+ val sum = new AtomicLong()
f.foreach(v => sum.addAndGet(v))
- // The value should be 45
- assert(sum.get() == -1)
+ throw new Exception("Success, processed records: " + sum.get())
}
val exception = intercept[Exception] {
spark.range(10).repartition(1).foreachPartition(func)
}
- assert(exception.getMessage.contains("45 did not equal -1"))
+ assert(exception.getMessage.contains("Success, processed records: 45"))
}
test("Dataset foreachPartition - java") {
val sum = new AtomicLong()
val exception = intercept[Exception] {
spark
- .range(10)
+ .range(11)
.repartition(1)
.foreachPartition(new ForeachPartitionFunction[JLong] {
override def call(t: JIterator[JLong]): Unit = {
t.asScala.foreach(v => sum.addAndGet(v))
- // The value should be 45
- assert(sum.get() == -1)
+ throw new Exception("Success, processed records: " + sum.get())
}
})
}
- assert(exception.getMessage.contains("45 did not equal -1"))
+ assert(exception.getMessage.contains("Success, processed records: 55"))
}
test("Dataset foreach: change not visible to client") {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
index 76608559866fa..923aa5af75ba8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionSuite.scala
@@ -20,9 +20,9 @@ import scala.reflect.runtime.universe.typeTag
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.connect.common.UdfPacket
import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.SparkSerDeUtils
class UserDefinedFunctionSuite extends ConnectFunSuite {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 5a909ab8b4178..4106d298dbe2b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -25,7 +25,7 @@ import scala.util.Properties
import org.apache.commons.io.output.ByteArrayOutputStream
import org.scalatest.BeforeAndAfterEach
-import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
+import org.apache.spark.sql.test.{IntegrationTestUtils, RemoteSparkSession}
class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index 7901008bc12e1..770143f2e9b4e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -30,7 +30,7 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.spark.connect.proto.AddArtifactsRequest
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
class ArtifactSuite extends ConnectFunSuite with BeforeAndAfterEach {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 3d7a80b1fb61b..1100babde795f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -24,7 +24,7 @@ import java.util.regex.Pattern
import com.typesafe.tools.mima.core._
import com.typesafe.tools.mima.lib.MiMaLib
-import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
+import org.apache.spark.sql.test.IntegrationTestUtils._
/**
* A tool for checking the binary compatibility of the connect client API against the spark SQL
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
index 625d4cf43e1a8..ca23436675f87 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ClassFinderSuite.scala
@@ -20,7 +20,7 @@ import java.nio.file.Paths
import org.apache.commons.io.FileUtils
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.util.SparkFileUtils
class ClassFinderSuite extends ConnectFunSuite {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
index 1dc1fd567ec1a..e1d4a18d0ff60 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientBuilderParseTestSuite.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.connect.client
import java.util.UUID
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
/**
* Test suite for [[SparkConnectClient.Builder]] parsing and configuration.
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 6348e0e49ca3d..80e245ec78b7d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -31,8 +31,8 @@ import org.apache.spark.SparkException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, AnalyzePlanRequest, AnalyzePlanResponse, ArtifactStatusesRequest, ArtifactStatusesResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc}
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.test.ConnectFunSuite
class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 2a499cc548fd8..b6ad27d3e5287 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils._
import org.apache.spark.sql.catalyst.util.SparkIntervalUtils._
import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.connect.client.arrow.FooEnum.FooEnum
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.sql.types.{ArrayType, DataType, DayTimeIntervalType, Decimal, DecimalType, IntegerType, Metadata, SQLUserDefinedType, StructType, UserDefinedType, YearMonthIntervalType}
/**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
deleted file mode 100644
index 33540bf498535..0000000000000
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.connect.client.util
-
-import java.io.{BufferedOutputStream, File}
-import java.util.concurrent.TimeUnit
-
-import scala.io.Source
-
-import org.scalatest.BeforeAndAfterAll
-import sys.process._
-
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.client.SparkConnectClient
-import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
-import org.apache.spark.sql.connect.common.config.ConnectCommon
-
-/**
- * An util class to start a local spark connect server in a different process for local E2E tests.
- * Pre-running the tests, the spark connect artifact needs to be built using e.g. `build/sbt
- * package`. It is designed to start the server once but shared by all tests. It is equivalent to
- * use the following command to start the connect server via command line:
- *
- * {{{
- * bin/spark-shell \
- * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \
- * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
- * }}}
- *
- * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed
- * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` to
- * print the server process output in the console to debug server start stop problems.
- */
-object SparkConnectServerUtils {
-
- // Server port
- val port: Int =
- ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
-
- @volatile private var stopped = false
-
- private var consoleOut: BufferedOutputStream = _
- private val serverStopCommand = "q"
-
- private lazy val sparkConnect: Process = {
- debug("Starting the Spark Connect Server...")
- val connectJar = findJar(
- "connector/connect/server",
- "spark-connect-assembly",
- "spark-connect").getCanonicalPath
-
- val builder = Process(
- Seq(
- "bin/spark-submit",
- "--driver-class-path",
- connectJar,
- "--conf",
- s"spark.connect.grpc.binding.port=$port") ++ testConfigs ++ debugConfigs ++ Seq(
- "--class",
- "org.apache.spark.sql.connect.SimpleSparkConnectService",
- connectJar),
- new File(sparkHome))
-
- val io = new ProcessIO(
- in => consoleOut = new BufferedOutputStream(in),
- out => Source.fromInputStream(out).getLines.foreach(debug),
- err => Source.fromInputStream(err).getLines.foreach(debug))
- val process = builder.run(io)
-
- // Adding JVM shutdown hook
- sys.addShutdownHook(stop())
- process
- }
-
- /**
- * As one shared spark will be started for all E2E tests, for tests that needs some special
- * configs, we add them here
- */
- private def testConfigs: Seq[String] = {
- // To find InMemoryTableCatalog for V2 writer tests
- val catalystTestJar =
- tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true)
- .map(clientTestJar => Seq(clientTestJar.getCanonicalPath))
- .getOrElse(Seq.empty)
-
- // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
- val connectClientTestJar = tryFindJar(
- "connector/connect/client/jvm",
- // SBT passes the client & test jars to the server process automatically.
- // So we skip building or finding this jar for SBT.
- "sbt-tests-do-not-need-this-jar",
- "spark-connect-client-jvm",
- test = true)
- .map(clientTestJar => Seq(clientTestJar.getCanonicalPath))
- .getOrElse(Seq.empty)
-
- val allJars = catalystTestJar ++ connectClientTestJar
- val jarsConfigs = Seq("--jars", allJars.mkString(","))
-
- // Use InMemoryTableCatalog for V2 writer tests
- val writerV2Configs = Seq(
- "--conf",
- "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")
-
- // Run tests using hive
- val hiveTestConfigs = {
- val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
- "hive"
- } else {
- // scalastyle:off println
- println(
- "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
- "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
- "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
- "2. Test with sbt: run test with `-Phive` profile")
- // scalastyle:on println
- // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive
- // module to avoid unexpected loading of `DataSourceRegister` in hive module during
- // testing without `-Phive` profile.
- IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded()
- "in-memory"
- }
- Seq("--conf", s"spark.sql.catalogImplementation=$catalogImplementation")
- }
-
- // Make the server terminate reattachable streams every 1 second and 123 bytes,
- // to make the tests exercise reattach.
- val reattachExecuteConfigs = Seq(
- "--conf",
- "spark.connect.execute.reattachable.senderMaxStreamDuration=1s",
- "--conf",
- "spark.connect.execute.reattachable.senderMaxStreamSize=123")
-
- jarsConfigs ++ writerV2Configs ++ hiveTestConfigs ++ reattachExecuteConfigs
- }
-
- def start(): Unit = {
- assert(!stopped)
- sparkConnect
- }
-
- def stop(): Int = {
- stopped = true
- debug("Stopping the Spark Connect Server...")
- try {
- consoleOut.write(serverStopCommand.getBytes)
- consoleOut.flush()
- consoleOut.close()
- } catch {
- case e: Throwable =>
- debug(e)
- sparkConnect.destroy()
- }
-
- val code = sparkConnect.exitValue()
- debug(s"Spark Connect Server is stopped with exit code: $code")
- code
- }
-}
-
-trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
- import SparkConnectServerUtils._
- var spark: SparkSession = _
- protected lazy val serverPort: Int = port
-
- override def beforeAll(): Unit = {
- super.beforeAll()
- SparkConnectServerUtils.start()
- spark = SparkSession
- .builder()
- .client(SparkConnectClient.builder().port(serverPort).build())
- .create()
-
- // Retry and wait for the server to start
- val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
- var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
- var success = false
- val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")
-
- while (!success && System.nanoTime() < stop) {
- try {
- // Run a simple query to verify the server is really up and ready
- val result = spark
- .sql("select val from (values ('Hello'), ('World')) as t(val)")
- .collect()
- assert(result.length == 2)
- success = true
- debug("Spark Connect Server is up.")
- } catch {
- // ignored the error
- case e: Throwable =>
- error.addSuppressed(e)
- Thread.sleep(sleepInternalMs)
- sleepInternalMs *= 2
- }
- }
-
- // Throw error if failed
- if (!success) {
- debug(error)
- throw error
- }
- }
-
- override def afterAll(): Unit = {
- try {
- if (spark != null) spark.stop()
- } catch {
- case e: Throwable => debug(e)
- }
- spark = null
- super.afterAll()
- }
-}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index 944a999a860b6..dc4d441ec3015 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -29,11 +29,11 @@ import org.scalatest.time.SpanSugar._
import org.apache.spark.api.java.function.VoidFunction2
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession, SQLHelper}
-import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.window
import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.util.SparkFileUtils
class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
index cdb6b9a2e9c10..2fab6e8e3c843 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateStreamingSuite.scala
@@ -23,9 +23,9 @@ import org.scalatest.concurrent.Eventually.eventually
import org.scalatest.concurrent.Futures.timeout
import org.scalatest.time.SpanSugar._
-import org.apache.spark.sql.{SparkSession, SQLHelper}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Append
-import org.apache.spark.sql.connect.client.util.QueryTest
+import org.apache.spark.sql.test.{QueryTest, SQLHelper}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
case class ClickEvent(id: String, timestamp: Timestamp)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala
index 2911e4e016efa..aed6c55b3e7fb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryProgressSuite.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.util.ConnectFunSuite
+import org.apache.spark.sql.test.ConnectFunSuite
import org.apache.spark.sql.types.StructType
class StreamingQueryProgressSuite extends ConnectFunSuite {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala
similarity index 97%
rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala
index 0a1e794c8e72e..8d69d91a34f7d 100755
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/ConnectFunSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/ConnectFunSuite.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client.util
+package org.apache.spark.sql.test
import java.nio.file.Path
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala
similarity index 87%
rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala
index 4d88565308f2d..61d08912aec23 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/IntegrationTestUtils.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client.util
+package org.apache.spark.sql.test
import java.io.File
import java.nio.file.{Files, Paths}
@@ -30,8 +30,12 @@ object IntegrationTestUtils {
// System properties used for testing and debugging
private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
+ private val DEBUG_SC_JVM_CLIENT_ENV = "SPARK_DEBUG_SC_JVM_CLIENT"
// Enable this flag to print all server logs to the console
- private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
+ private[sql] val isDebug = {
+ System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean ||
+ Option(System.getenv(DEBUG_SC_JVM_CLIENT_ENV)).exists(_.toBoolean)
+ }
private[sql] lazy val scalaVersion = {
versionNumberString.split('.') match {
@@ -49,8 +53,14 @@ object IntegrationTestUtils {
sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
}
- private[connect] def debugConfigs: Seq[String] = {
- val log4j2 = s"$sparkHome/connector/connect/client/jvm/src/test/resources/log4j2.properties"
+ private[sql] lazy val connectClientHomeDir = s"$sparkHome/connector/connect/client/jvm"
+
+ private[sql] lazy val connectClientTestClassDir = {
+ s"$connectClientHomeDir/target/$scalaDir/test-classes"
+ }
+
+ private[sql] def debugConfigs: Seq[String] = {
+ val log4j2 = s"$connectClientHomeDir/src/test/resources/log4j2.properties"
if (isDebug) {
Seq(
// Enable to see the server plan change log
@@ -70,9 +80,9 @@ object IntegrationTestUtils {
// Log server start stop debug info into console
// scalastyle:off println
- private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
+ private[sql] def debug(msg: String): Unit = if (isDebug) println(msg)
// scalastyle:on println
- private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
+ private[sql] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
private[sql] lazy val isSparkHiveJarAvailable: Boolean = {
val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/" +
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
similarity index 99%
rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
index a0d3d4368dd28..adbd8286090d9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/QueryTest.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/QueryTest.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client.util
+package org.apache.spark.sql.test
import java.util.TimeZone
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
new file mode 100644
index 0000000000000..8a8f739a7c502
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.test
+
+import java.io.{File, IOException, OutputStream}
+import java.lang.ProcessBuilder
+import java.lang.ProcessBuilder.Redirect
+import java.nio.file.Paths
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.duration.FiniteDuration
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkBuildInfo
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryPolicy
+import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.test.IntegrationTestUtils._
+
+/**
+ * A utility class that starts a local Spark Connect server in a separate process for local E2E
+ * tests. Before running the tests, the Spark Connect artifacts need to be built, e.g. with
+ * `build/sbt package`. The server is started once and shared by all tests. It is equivalent to
+ * starting the connect server from the command line with:
+ *
+ * {{{
+ * bin/spark-shell \
+ * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \
+ * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+ * }}}
+ *
+ * Set system property `spark.test.home` or env variable `SPARK_HOME` if the test is not executed
+ * from the Spark project top folder. Set system property `spark.debug.sc.jvm.client=true` or
+ * environment variable `SPARK_DEBUG_SC_JVM_CLIENT=true` to print the server process output to
+ * the console when debugging server start/stop problems.
+ */
+object SparkConnectServerUtils {
+
+ // Server port
+ val port: Int =
+ ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
+
+ @volatile private var stopped = false
+
+ private var consoleOut: OutputStream = _
+ private val serverStopCommand = "q"
+
+ private lazy val sparkConnect: java.lang.Process = {
+ debug("Starting the Spark Connect Server...")
+ val connectJar = findJar(
+ "connector/connect/server",
+ "spark-connect-assembly",
+ "spark-connect").getCanonicalPath
+
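+ // Assemble the spark-submit invocation that runs SimpleSparkConnectService in a child process,
+ // with the connect server jar on the driver classpath and the gRPC port bound to `port`.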
+ val command = Seq.newBuilder[String]
+ command += "bin/spark-submit"
+ command += "--driver-class-path" += connectJar
+ command += "--class" += "org.apache.spark.sql.connect.SimpleSparkConnectService"
+ command += "--conf" += s"spark.connect.grpc.binding.port=$port"
+ command ++= testConfigs
+ command ++= debugConfigs
+ command += connectJar
+ val builder = new ProcessBuilder(command.result(): _*)
+ builder.directory(new File(sparkHome))
+ val environment = builder.environment()
+ environment.remove("SPARK_DIST_CLASSPATH")
+ if (isDebug) {
+ builder.redirectError(Redirect.INHERIT)
+ builder.redirectOutput(Redirect.INHERIT)
+ }
+
+ val process = builder.start()
+ consoleOut = process.getOutputStream
+
+ // Adding JVM shutdown hook
+ sys.addShutdownHook(stop())
+ process
+ }
+
+ /**
+ * As one shared Spark Connect server is started for all E2E tests, tests that need special
+ * configs have them added here.
+ */
+ private def testConfigs: Seq[String] = {
+ // To find InMemoryTableCatalog for V2 writer tests
+ val catalystTestJar =
+ findJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = true).getCanonicalPath
+
+ val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
+ "hive"
+ } else {
+ // scalastyle:off println
+ println(
+ "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
+ "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
+ "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
+ "2. Test with sbt: run test with `-Phive` profile")
+ // scalastyle:on println
+ // SPARK-43647: Proactively cleaning the `classes` and `test-classes` dir of hive
+ // module to avoid unexpected loading of `DataSourceRegister` in hive module during
+ // testing without `-Phive` profile.
+ IntegrationTestUtils.cleanUpHiveClassesDirIfNeeded()
+ "in-memory"
+ }
+ val confs = Seq(
+ // Use InMemoryTableCatalog for V2 writer tests
+ "spark.sql.catalog.testcat=org.apache.spark.sql.connector.catalog.InMemoryTableCatalog",
+ // Try to use the hive catalog, fall back to in-memory if it is not there.
+ "spark.sql.catalogImplementation=" + catalogImplementation,
+ // Make the server terminate reattachable streams every 1 second and 123 bytes,
+ // to make the tests exercise reattach.
+ "spark.connect.execute.reattachable.senderMaxStreamDuration=1s",
+ "spark.connect.execute.reattachable.senderMaxStreamSize=123",
+ // Disable UI
+ "spark.ui.enabled=false")
+ Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil)
+ }
+
+ def start(): Unit = {
+ assert(!stopped)
+ sparkConnect
+ }
+
+ def stop(): Int = {
+ stopped = true
+ debug("Stopping the Spark Connect Server...")
+ try {
+ consoleOut.write(serverStopCommand.getBytes)
+ consoleOut.flush()
+ consoleOut.close()
+ if (!sparkConnect.waitFor(2, TimeUnit.SECONDS)) {
+ sparkConnect.destroyForcibly()
+ }
+ val code = sparkConnect.exitValue()
+ debug(s"Spark Connect Server is stopped with exit code: $code")
+ code
+ } catch {
+ case e: IOException if e.getMessage.contains("Stream closed") =>
+ -1
+ case e: Throwable =>
+ debug(e)
+ sparkConnect.destroyForcibly()
+ throw e
+ }
+ }
+
+ def syncTestDependencies(spark: SparkSession): Unit = {
+ // Both SBT & Maven pass the test-classes as a directory instead of a jar.
+ val testClassesPath = Paths.get(IntegrationTestUtils.connectClientTestClassDir)
+ spark.client.artifactManager.addClassDir(testClassesPath)
+
+ // We need scalatest & scalactic on the session's classpath to make the tests work.
+ val jars = System
+ .getProperty("java.class.path")
+ .split(File.pathSeparatorChar)
+ .filter { e: String =>
+ val fileName = e.substring(e.lastIndexOf(File.separatorChar) + 1)
+ fileName.endsWith(".jar") &&
+ (fileName.startsWith("scalatest") || fileName.startsWith("scalactic"))
+ }
+ .map(e => Paths.get(e).toUri)
+ spark.client.artifactManager.addArtifacts(jars)
+ }
+
+ def createSparkSession(): SparkSession = {
+ SparkConnectServerUtils.start()
+
+ val spark = SparkSession
+ .builder()
+ .client(
+ SparkConnectClient
+ .builder()
+ .userId("test")
+ .port(port)
+ .retryPolicy(RetryPolicy(maxRetries = 7, maxBackoff = FiniteDuration(10, "s")))
+ .build())
+ .create()
+
+ // Execute an RPC which will get retried until the server is up.
+ assert(spark.version == SparkBuildInfo.spark_version)
+
+ // Auto-sync dependencies.
+ SparkConnectServerUtils.syncTestDependencies(spark)
+
+ spark
+ }
+}
+
+trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
+ import SparkConnectServerUtils._
+ var spark: SparkSession = _
+ protected lazy val serverPort: Int = port
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark = createSparkSession()
+ }
+
+ override def afterAll(): Unit = {
+ try {
+ if (spark != null) spark.stop()
+ } catch {
+ case e: Throwable => debug(e)
+ }
+ spark = null
+ super.afterAll()
+ }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
similarity index 97%
rename from connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala
rename to connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
index f357270e20fe2..12212492e370b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLHelper.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/SQLHelper.scala
@@ -14,13 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql
+package org.apache.spark.sql.test
import java.io.File
import java.util.UUID
import org.scalatest.Assertions.fail
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.util.{SparkErrorUtils, SparkFileUtils}
trait SQLHelper {
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index 136a31fca3cda..b1a7746a84ad6 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -145,10 +145,28 @@ class ArtifactManager(
addArtifacts(classFinders.asScala.flatMap(_.findClasses()))
}
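+ /**
+ * Add all compiled class files under `base` to the session as class artifacts. Each file is
+ * registered with its path relative to `base`, so the server can recreate the original package
+ * directory layout.
+ */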
+ private[sql] def addClassDir(base: Path): Unit = {
+ if (!Files.isDirectory(base)) {
+ return
+ }
+ val builder = Seq.newBuilder[Artifact]
+ val stream = Files.walk(base)
+ try {
+ stream.forEach { path =>
+ if (Files.isRegularFile(path) && path.toString.endsWith(".class")) {
+ builder += Artifact.newClassArtifact(base.relativize(path), new LocalFile(path))
+ }
+ }
+ } finally {
+ stream.close()
+ }
+ addArtifacts(builder.result())
+ }
+
/**
* Add a number of artifacts to the session.
*/
- private def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
+ private[client] def addArtifacts(artifacts: Iterable[Artifact]): Unit = {
if (artifacts.isEmpty) {
return
}
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 8b6f070b8f5b9..a6841e7f1182e 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -26,7 +26,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
-private[client] class GrpcRetryHandler(
+private[sql] class GrpcRetryHandler(
private val retryPolicy: GrpcRetryHandler.RetryPolicy,
private val sleep: Long => Unit = Thread.sleep) {
@@ -146,7 +146,7 @@ private[client] class GrpcRetryHandler(
}
}
-private[client] object GrpcRetryHandler extends Logging {
+private[sql] object GrpcRetryHandler extends Logging {
/**
* Retries the given function with exponential backoff according to the client's retryPolicy.
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index c41f6dfaae1fc..a0853cc0621fa 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -58,7 +58,7 @@ private[sql] class SparkConnectClient(
// a new client will create a new session ID.
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)
- private[client] val artifactManager: ArtifactManager = {
+ private[sql] val artifactManager: ArtifactManager = {
new ArtifactManager(configuration, sessionId, bstub, stub)
}
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
index 19890558ab212..e85a2a40da2bf 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/ProtoDataTypes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.common
import org.apache.spark.connect.proto
-private[connect] object ProtoDataTypes {
+private[sql] object ProtoDataTypes {
val NullType: proto.DataType = proto.DataType
.newBuilder()
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
index 3f594d79b627b..dca65cf905fc8 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/config/ConnectCommon.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.connect.common.config
-private[connect] object ConnectCommon {
+private[sql] object ConnectCommon {
val CONNECT_GRPC_BINDING_PORT: Int = 15002
val CONNECT_GRPC_MAX_MESSAGE_SIZE: Int = 128 * 1024 * 1024;
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index c1dd7820c55f0..a2df11eeb5832 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_PREFIXES
+import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES, EXECUTOR_USER_CLASS_PATH_FIRST}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
@@ -162,15 +162,37 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
*/
def classloader: ClassLoader = {
val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
- val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES).nonEmpty) {
- val stubClassLoader =
- StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES))
- new ChildFirstURLClassLoader(
- urls.toArray,
- stubClassLoader,
- Utils.getContextOrSparkClassLoader)
+ val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
+ val userClasspathFirst = SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
+ val loader = if (prefixes.nonEmpty) {
+ // Two things you need to know about classloaders for all of this to make sense:
+ // 1. A classloader needs to be able to fully define a class.
+ // 2. Classes are loaded lazily. Only when a class is used are the classes it references
+ // loaded.
+ // This makes stubbing a bit more complicated than you'd expect. We cannot put the stubbing
+ // classloader as a fallback at the end of the loading process, because classes that have been
+ // found in one of the parent classloaders and that contain a reference to a missing,
+ // to-be-stubbed class would still fail with classloading errors later on. The way we
+ // currently fix this is by making the stubbing classloader the last classloader in the
+ // delegation chain.
+ if (userClasspathFirst) {
+ // USER -> SYSTEM -> STUB
+ new ChildFirstURLClassLoader(
+ urls.toArray,
+ StubClassLoader(Utils.getContextOrSparkClassLoader, prefixes))
+ } else {
+ // SYSTEM -> USER -> STUB
+ new ChildFirstURLClassLoader(
+ urls.toArray,
+ StubClassLoader(null, prefixes),
+ Utils.getContextOrSparkClassLoader)
+ }
} else {
- new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ if (userClasspathFirst) {
+ new ChildFirstURLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ } else {
+ new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+ }
}
logDebug(s"Using class loader: $loader, containing urls: $urls")
diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
index e27376e2b83d8..8d903c2a3e400 100644
--- a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
@@ -18,6 +18,8 @@ package org.apache.spark.util
import org.apache.xbean.asm9.{ClassWriter, Opcodes}
+import org.apache.spark.internal.Logging
+
/**
* [[ClassLoader]] that replaces missing classes with stubs, if they cannot be found. It will only
* do this for classes that are marked for stubbing.
@@ -27,11 +29,12 @@ import org.apache.xbean.asm9.{ClassWriter, Opcodes}
* the class and therefore is safe to replace by a stub.
*/
class StubClassLoader(parent: ClassLoader, shouldStub: String => Boolean)
- extends ClassLoader(parent) {
+ extends ClassLoader(parent) with Logging {
override def findClass(name: String): Class[_] = {
if (!shouldStub(name)) {
throw new ClassNotFoundException(name)
}
+ logDebug(s"Generating stub for $name")
val bytes = StubClassLoader.generateStub(name)
defineClass(name, bytes, 0, bytes.length)
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index a4a76efbd34a0..2f437eeb75cc1 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -860,7 +860,6 @@ object SparkConnectClient {
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
)
},
-
dependencyOverrides ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(