ReplE2ESuite.scala

@@ -23,7 +23,6 @@ import java.util.concurrent.{Executors, Semaphore, TimeUnit}
 import scala.util.Properties

 import org.apache.commons.io.output.ByteArrayOutputStream
-import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalatest.BeforeAndAfterEach

 import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}

@@ -51,29 +50,26 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
   }

   override def beforeAll(): Unit = {
-    // TODO(SPARK-44121) Remove this check condition
-    if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
-      super.beforeAll()
-      ammoniteOut = new ByteArrayOutputStream()
-      testSuiteOut = new PipedOutputStream()
-      // Connect the `testSuiteOut` and `ammoniteIn` pipes
-      ammoniteIn = new PipedInputStream(testSuiteOut)
-      errorStream = new ByteArrayOutputStream()
-
-      val args = Array("--port", serverPort.toString)
-      val task = new Runnable {
-        override def run(): Unit = {
-          ConnectRepl.doMain(
-            args = args,
-            semaphore = Some(semaphore),
-            inputStream = ammoniteIn,
-            outputStream = ammoniteOut,
-            errorStream = errorStream)
-        }
-      }
-
-      executorService.submit(task)
-    }
+    super.beforeAll()
+    ammoniteOut = new ByteArrayOutputStream()
+    testSuiteOut = new PipedOutputStream()
+    // Connect the `testSuiteOut` and `ammoniteIn` pipes
+    ammoniteIn = new PipedInputStream(testSuiteOut)
+    errorStream = new ByteArrayOutputStream()
+
+    val args = Array("--port", serverPort.toString)
+    val task = new Runnable {
+      override def run(): Unit = {
+        ConnectRepl.doMain(
+          args = args,
+          semaphore = Some(semaphore),
+          inputStream = ammoniteIn,
+          outputStream = ammoniteOut,
+          errorStream = errorStream)
+      }
+    }
+
+    executorService.submit(task)
   }

   override def afterAll(): Unit = {

Review thread on the removed TODO line:

Contributor: I suggest these changes could be merged into 3.5 to avoid future conflicts. WDYT @dongjoon-hyun

Contributor: At least the current one related to connect.

dongjoon-hyun (Member, Author), Aug 24, 2023: For this connect area, I'll follow your guideline because you are the original author, @LuciferYang . :)
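For context on the harness this hunk edits: the suite drives the Ammonite-based Connect REPL through a `PipedOutputStream`/`PipedInputStream` pair, so bytes the test writes become the REPL's stdin on the other end. A minimal, self-contained sketch of that wiring (variable names here are illustrative, not from the patch):

```scala
import java.io.{PipedInputStream, PipedOutputStream}

// Writer end held by the test; reader end handed to the child as its stdin.
val testOut = new PipedOutputStream()
val childIn = new PipedInputStream(testOut) // connects the two pipes

// Anything the test writes...
testOut.write("println(1 + 1)\n".getBytes("UTF-8"))
testOut.flush()

// ...can be read from the other end as if it had been typed interactively.
val buf = new Array[Byte](64)
val n = childIn.read(buf)
println(new String(buf, 0, n, "UTF-8")) // prints: println(1 + 1)
```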

RemoteSparkSession.scala

@@ -21,9 +21,7 @@ import java.util.concurrent.TimeUnit

 import scala.io.Source

-import org.apache.commons.lang3.{JavaVersion, SystemUtils}
-import org.scalactic.source.Position
-import org.scalatest.{BeforeAndAfterAll, Tag}
+import org.scalatest.BeforeAndAfterAll
 import sys.process._

 import org.apache.spark.sql.SparkSession

@@ -180,44 +178,41 @@ trait RemoteSparkSession extends ConnectFunSuite with BeforeAndAfterAll {
   protected lazy val serverPort: Int = port

   override def beforeAll(): Unit = {
-    // TODO(SPARK-44121) Remove this check condition
-    if (SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17)) {
-      super.beforeAll()
-      SparkConnectServerUtils.start()
-      spark = SparkSession
-        .builder()
-        .client(SparkConnectClient.builder().port(serverPort).build())
-        .create()
-
-      // Retry and wait for the server to start
-      val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
-      var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
-      var success = false
-      val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")
-
-      while (!success && System.nanoTime() < stop) {
-        try {
-          // Run a simple query to verify the server is really up and ready
-          val result = spark
-            .sql("select val from (values ('Hello'), ('World')) as t(val)")
-            .collect()
-          assert(result.length == 2)
-          success = true
-          debug("Spark Connect Server is up.")
-        } catch {
-          // ignored the error
-          case e: Throwable =>
-            error.addSuppressed(e)
-            Thread.sleep(sleepInternalMs)
-            sleepInternalMs *= 2
-        }
-      }
-
-      // Throw error if failed
-      if (!success) {
-        debug(error)
-        throw error
-      }
-    }
+    super.beforeAll()
+    SparkConnectServerUtils.start()
+    spark = SparkSession
+      .builder()
+      .client(SparkConnectClient.builder().port(serverPort).build())
+      .create()
+
+    // Retry and wait for the server to start
+    val stop = System.nanoTime() + TimeUnit.MINUTES.toNanos(1) // ~1 min
+    var sleepInternalMs = TimeUnit.SECONDS.toMillis(1) // 1s with * 2 backoff
+    var success = false
+    val error = new RuntimeException(s"Failed to start the test server on port $serverPort.")
+
+    while (!success && System.nanoTime() < stop) {
+      try {
+        // Run a simple query to verify the server is really up and ready
+        val result = spark
+          .sql("select val from (values ('Hello'), ('World')) as t(val)")
+          .collect()
+        assert(result.length == 2)
+        success = true
+        debug("Spark Connect Server is up.")
+      } catch {
+        // ignored the error
+        case e: Throwable =>
+          error.addSuppressed(e)
+          Thread.sleep(sleepInternalMs)
+          sleepInternalMs *= 2
+      }
+    }
+
+    // Throw error if failed
+    if (!success) {
+      debug(error)
+      throw error
+    }
   }
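The readiness loop in this hunk is a capped retry with exponential backoff: probe the server with a trivial query, double the sleep after each failure, and attach every attempt's exception to one summary error via `addSuppressed`. The same pattern in isolation (a sketch; `probe` stands in for the suite's `spark.sql(...).collect()` check):

```scala
import java.util.concurrent.TimeUnit

// Retry `probe` for up to one minute, doubling the sleep after each failure.
def awaitReady(probe: () => Unit): Unit = {
  val deadline = System.nanoTime() + TimeUnit.MINUTES.toNanos(1)
  var sleepMs = TimeUnit.SECONDS.toMillis(1)
  val failure = new RuntimeException("Server did not become ready in time.")
  var ready = false
  while (!ready && System.nanoTime() < deadline) {
    try {
      probe()
      ready = true
    } catch {
      case e: Throwable =>
        failure.addSuppressed(e) // keep each attempt's error for diagnosis
        Thread.sleep(sleepMs)
        sleepMs *= 2
    }
  }
  if (!ready) throw failure // carries every suppressed attempt error
}
```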

@@ -230,17 +225,4 @@
     spark = null
     super.afterAll()
   }
-
-  /**
-   * SPARK-44259: override test function to skip `RemoteSparkSession-based` tests as default, we
-   * should delete this function after SPARK-44121 is completed.
-   */
-  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
-      pos: Position): Unit = {
-    super.test(testName, testTags: _*) {
-      // TODO(SPARK-44121) Re-enable Arrow-based connect tests in Java 21
-      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
-      testFun
-    }
-  }
 }

Review thread on the removed `test` override:

Contributor: Sorry, please ignore my previous comment.

dongjoon-hyun (Member, Author): No problem at all. Thank you for review!
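The deleted `test` override was what skipped every `RemoteSparkSession`-based test on Java 21: ScalaTest's `assume` throws `TestCanceledException`, so a failing precondition cancels the test instead of failing it. A stripped-down sketch of the same pattern (the suite name and the Java-version condition here are illustrative, not from the patch):

```scala
import org.scalactic.source.Position
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

class ConditionalSuite extends AnyFunSuite {
  // Wrap every registered test so it is canceled (not failed) when the
  // environment does not meet some precondition.
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
      pos: Position): Unit = {
    super.test(testName, testTags: _*) {
      assume(scala.util.Properties.isJavaAtLeast("17"), "requires Java 17+")
      testFun
    }
  }

  test("runs only when the precondition holds") {
    assert(1 + 1 == 2)
  }
}
```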

SparkConnectPlannerSuite.scala

@@ -21,7 +21,6 @@ import scala.collection.JavaConverters._

 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
-import org.apache.commons.lang3.{JavaVersion, SystemUtils}

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto

@@ -479,8 +478,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   }

   test("transform LocalRelation") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     val rows = (0 until 10).map { i =>
       InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i))
     }

@@ -582,8 +579,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   }

   test("transform UnresolvedStar and ExpressionString") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     val sql =
       "SELECT * FROM VALUES (1,'spark',1), (2,'hadoop',2), (3,'kafka',3) AS tab(id, name, value)"
     val input = proto.Relation

@@ -620,8 +615,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
   }

   test("transform UnresolvedStar with target field") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     val rows = (0 until 10).map { i =>
       InternalRow(InternalRow(InternalRow(i, i + 1)))
     }
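With the guards gone, the commons-lang3 `SystemUtils`/`JavaVersion` imports disappear from this suite as well. If version gating is ever needed again, the JDK and the Scala standard library answer the same question without the extra dependency, e.g. (a sketch; `Runtime.version().feature()` requires Java 10+):

```scala
// Java-version gating without commons-lang3.
val feature = Runtime.version().feature() // e.g. 21 on JDK 21
val atMost17 = feature <= 17              // analogue of SystemUtils.isJavaVersionAtMost(JAVA_17)
val atLeast17 = scala.util.Properties.isJavaAtLeast("17")
println(s"feature=$feature atMost17=$atMost17 atLeast17=$atLeast17")
```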

SparkConnectProtoSuite.scala

@@ -21,7 +21,6 @@ import java.nio.file.{Files, Paths}
 import scala.collection.JavaConverters._

 import com.google.protobuf.ByteString
-import org.apache.commons.lang3.{JavaVersion, SystemUtils}

 import org.apache.spark.{SparkClassNotFoundException, SparkIllegalArgumentException}
 import org.apache.spark.connect.proto

@@ -695,8 +694,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }

   test("WriteTo with create") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("testcat.table_name") {
       spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

@@ -724,8 +721,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }

   test("WriteTo with create and using") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
     withTable("testcat.table_name") {
       spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

@@ -763,8 +758,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }

   test("WriteTo with append") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("testcat.table_name") {
       spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

@@ -796,8 +789,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }

   test("WriteTo with overwrite") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("testcat.table_name") {
       spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)

@@ -851,8 +842,6 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   }

   test("WriteTo with overwritePartitions") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("testcat.table_name") {
       spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
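All of the `WriteTo` tests in this file follow one recipe: register `InMemoryTableCatalog` (a test-only DataSource V2 catalog from the spark-catalyst test jar) under the `testcat` name, then exercise the `DataFrameWriterV2` verbs (`create`, `append`, `overwrite`, `overwritePartitions`) against it. Roughly, outside the suite, the recipe looks like this (a sketch that assumes the test jar is on the classpath):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()

// Register the in-memory test catalog under the name "testcat".
spark.conf.set(
  "spark.sql.catalog.testcat",
  "org.apache.spark.sql.connector.catalog.InMemoryTableCatalog")

val df = spark.range(3).toDF("id")
df.writeTo("testcat.table_name").create() // v2 CREATE TABLE AS SELECT
df.writeTo("testcat.table_name").append() // v2 INSERT INTO the created table
```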

SparkConnectServiceSuite.scala

@@ -29,7 +29,6 @@ import io.grpc.stub.StreamObserver
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
-import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.mockito.Mockito.when
 import org.scalatest.Tag
 import org.scalatestplus.mockito.MockitoSugar

@@ -157,8 +156,6 @@ class SparkConnectServiceSuite

   test("SPARK-41224: collect data using arrow") {
     withEvents { verifyEvents =>
-      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
       val instance = new SparkConnectService(false)
       val connect = new MockRemoteSession()
       val context = proto.UserContext

@@ -246,8 +243,6 @@ class SparkConnectServiceSuite

   test("SPARK-44776: LocalTableScanExec") {
     withEvents { verifyEvents =>
-      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
       val instance = new SparkConnectService(false)
       val connect = new MockRemoteSession()
       val context = proto.UserContext

@@ -319,8 +314,6 @@ class SparkConnectServiceSuite
     // Set 10 KiB as the batch size limit
     val batchSize = 10 * 1024
     withSparkConf("spark.connect.grpc.arrow.maxBatchSize" -> batchSize.toString) {
-      // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-      assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
       val instance = new SparkConnectService(false)
       val connect = new MockRemoteSession()
       val context = proto.UserContext

@@ -742,8 +735,6 @@ class SparkConnectServiceSuite
   }

   test("Test observe response") {
-    // TODO(SPARK-44121) Renable Arrow-based connect tests in Java 21
-    assume(SystemUtils.isJavaVersionAtMost(JavaVersion.JAVA_17))
     withTable("test") {
       spark.sql("""
         | CREATE TABLE test (col1 INT, col2 STRING)
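For reference on what "Test observe response" covers: `Dataset.observe` attaches named aggregate metrics to a plan, and the Connect server forwards the collected values to the client as an observed-metrics response. A minimal sketch of the API on a plain `SparkSession` (metric retrieval normally goes through a `QueryExecutionListener`, omitted here):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, lit, max}

val spark = SparkSession.builder().master("local[2]").getOrCreate()

// Attach two named metrics to the plan; they are computed as the query runs.
val observed = spark
  .range(10)
  .observe("stats", count(lit(1)).as("rows"), max("id").as("max_id"))

observed.collect() // executing the query populates the "stats" metrics
```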