diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 50ff08f997cb..f211f84962f1 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -196,13 +196,14 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
         signal.wait()
         result = partitions.remove(currentPartitionId)
       }
-      error match {
-        case NonFatal(e) =>
-          responseObserver.onError(error)
-          logError("Error while processing query.", e)
-          return
-        case fatal: Throwable => throw fatal
-        case null => result.get
-      }
+      if (error == null) {
+        result.get
+      } else if (NonFatal(error)) {
+        responseObserver.onError(error)
+        logError("Error while processing query.", error)
+        return
+      } else {
+        throw error
+      }
     }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 133ce980ecdd..e7520d5416bc 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,10 +16,12 @@
  */
 package org.apache.spark.sql.connect.planner
 
-import scala.concurrent.Promise
-import scala.concurrent.duration._
+import scala.collection.mutable
 
 import io.grpc.stub.StreamObserver
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{BigIntVector, Float8Vector}
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
@@ -28,7 +30,6 @@ import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.execution.ExplainMode
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.ThreadUtils
 
 /**
  * Testing Connect Service implementation.
@@ -67,6 +68,78 @@ class SparkConnectServiceSuite extends SharedSparkSession {
     }
   }
 
+  test("collect data using arrow") {
+    val instance = new SparkConnectService(false)
+    val connect = new MockRemoteSession()
+    val context = proto.UserContext
+      .newBuilder()
+      .setUserId("c1")
+      .build()
+    val plan = proto.Plan
+      .newBuilder()
+      .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+      .build()
+    val request = proto.ExecutePlanRequest
+      .newBuilder()
+      .setPlan(plan)
+      .setUserContext(context)
+      .build()
+
+    // Execute plan.
+    @volatile var done = false
+    val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+    instance.executePlan(
+      request,
+      new StreamObserver[proto.ExecutePlanResponse] {
+        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
+
+        override def onError(throwable: Throwable): Unit = throw throwable
+
+        override def onCompleted(): Unit = done = true
+      })
+
+    // The current implementation is expected to be blocking. This is here to make sure it is.
+    assert(done)
+
+    // 4 partitions + metrics
+    assert(responses.size == 5)
+
+    // Make sure the last response carries metrics only.
+    val last = responses.last
+    assert(last.hasMetrics && !(last.hasArrowBatch || last.hasJsonBatch))
+
+    val allocator = new RootAllocator()
+
+    // Check the data batches.
+    var expectedId = 0L
+    var previousEId = 0.0d
+    responses.dropRight(1).foreach { response =>
+      assert(response.hasArrowBatch)
+      val batch = response.getArrowBatch
+      assert(batch.getData != null)
+      assert(batch.getRowCount == 25)
+
+      val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
+      while (reader.loadNextBatch()) {
+        val root = reader.getVectorSchemaRoot
+        val idVector = root.getVector(0).asInstanceOf[BigIntVector]
+        val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
+        val numRows = root.getRowCount
+        var i = 0
+        while (i < numRows) {
+          assert(idVector.get(i) == expectedId)
+          expectedId += 1
+          val eid = eidVector.get(i)
+          assert(eid > previousEId)
+          previousEId = eid
+          i += 1
+        }
+      }
+      reader.close()
+    }
+    allocator.close()
+  }
+
   test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
     val instance = new SparkConnectService(false)
 
@@ -92,21 +165,23 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       .setUserContext(context)
       .build()
 
-    val promise = Promise[Seq[proto.ExecutePlanResponse]]
+    // The observer is executed inside this thread, so we can
+    // perform the checks directly inside the observer.
     instance.executePlan(
       request,
      new StreamObserver[proto.ExecutePlanResponse] {
-        private val responses = Seq.newBuilder[proto.ExecutePlanResponse]
+        override def onNext(v: proto.ExecutePlanResponse): Unit = {
+          fail("this should not receive responses")
+        }
 
-        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
-
-        override def onError(throwable: Throwable): Unit = promise.failure(throwable)
+        override def onError(throwable: Throwable): Unit = {
+          assert(throwable.isInstanceOf[SparkException])
+        }
 
-        override def onCompleted(): Unit = promise.success(responses.result())
+        override def onCompleted(): Unit = {
+          fail("this should not complete")
+        }
       })
-    intercept[SparkException] {
-      ThreadUtils.awaitResult(promise.future, 2.seconds)
-    }
   }
 
   test("Test explain mode in analyze response") {
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fb455b7f3f43..4f26ac8a06fd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -220,7 +220,6 @@ def test_create_global_temp_view(self):
         with self.assertRaises(_MultiThreadedRendezvous):
             self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
-    @unittest.skip("test_fill_na is flaky")
     def test_fill_na(self):
         # SPARK-41128: Test fill na
         query = """
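Why the SparkConnectStreamHandler hunk reorders the checks: scala.util.control.NonFatal matches null, because NonFatal.apply only returns false for the fatal Throwable subclasses and its trailing wildcard case matches a null argument. In the removed code the NonFatal extractor was tried first, so a successfully completed query (error == null) took the non-fatal arm, the "case null" arm was unreachable, and responseObserver.onError(null) was invoked. The standalone sketch below is illustrative only and not part of the patch (the object and method names are made up); it demonstrates the pitfall and the fixed ordering:

import scala.util.control.NonFatal

object NonFatalNullDemo extends App {
  // Mirrors the removed match: the NonFatal extractor is tried first.
  def describeBuggy(error: Throwable): String = error match {
    // NonFatal.unapply(null) returns Some(null), so this arm also
    // captures null and the "case null" arm below is never reached.
    case NonFatal(e) => s"non-fatal: $e"
    case fatal: Throwable => throw fatal
    case null => "no error"
  }

  // Mirrors the added code: test for null before applying NonFatal.
  def describeFixed(error: Throwable): String =
    if (error == null) "no error"
    else if (NonFatal(error)) s"non-fatal: $error"
    else throw error

  println(describeBuggy(null)) // prints "non-fatal: null", not "no error"
  println(describeFixed(null)) // prints "no error"
  println(describeFixed(new RuntimeException("boom"))) // prints "non-fatal: java.lang.RuntimeException: boom"
}

The fixed ordering also preserves fatal-error semantics: anything NonFatal rejects (VirtualMachineError, ThreadDeath, InterruptedException, LinkageError, ControlThrowable) falls through to the final else and is rethrown, just like the old "case fatal" arm.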