diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 092bdd00dc1c..b5d100e894d0 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
 
@@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
 
 class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
     extends Logging {
@@ -71,20 +73,83 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
 
     var numSent = 0
     if (numPartitions > 0) {
+      type Batch = (Array[Byte], Long)
+
       val batches = rows.mapPartitionsInternal(
         SparkConnectStreamHandler
           .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-      batches.collect().foreach { case (bytes, count) =>
-        val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-        val batch = proto.ExecutePlanResponse.ArrowBatch
-          .newBuilder()
-          .setRowCount(count)
-          .setData(ByteString.copyFrom(bytes))
-          .build()
-        response.setArrowBatch(batch)
-        responseObserver.onNext(response.build())
-        numSent += 1
+      val signal = new Object
+      val partitions = new Array[Array[Batch]](numPartitions)
+      var error: Option[Throwable] = None
+
+      // This callback is executed by the DAGScheduler thread.
+      // After fetching a partition, it inserts the partition into the `partitions`
+      // array and then wakes up the main thread.
+      val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
+        signal.synchronized {
+          partitions(partitionId) = partition
+          signal.notify()
+        }
+        ()
+      }
+
+      val future = spark.sparkContext.submitJob(
+        rdd = batches,
+        processPartition = (iter: Iterator[Batch]) => iter.toArray,
+        partitions = Seq.range(0, numPartitions),
+        resultHandler = resultHandler,
+        resultFunc = () => ())
+
+      // Collect errors and propagate them to the main thread.
+      future.onComplete { result =>
+        result.failed.foreach { throwable =>
+          signal.synchronized {
+            error = Some(throwable)
+            signal.notify()
+          }
+        }
+      }(ThreadUtils.sameThread)
+
+      // The main thread waits until the 0-th partition is available, sends it to the
+      // client, and then waits for the next partition. Unlike the implementation of
+      // [[Dataset#collectAsArrowToPython]], it sends the arrow batches from the main
+      // thread, so the DAGScheduler thread is never blocked by work unrelated to
+      // scheduling. This is particularly important if there are multiple users or
+      // clients running code at the same time.
+      var currentPartitionId = 0
+      while (currentPartitionId < numPartitions) {
+        val partition = signal.synchronized {
+          var part = partitions(currentPartitionId)
+          while (part == null && error.isEmpty) {
+            signal.wait()
+            part = partitions(currentPartitionId)
+          }
+          partitions(currentPartitionId) = null
+
+          error.foreach {
+            case NonFatal(e) =>
+              responseObserver.onError(e)
+              logError("Error while processing query.", e)
+              return
+            case other => throw other
+          }
+          part
+        }
+
+        partition.foreach { case (bytes, count) =>
+          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
+          val batch = proto.ExecutePlanResponse.ArrowBatch
+            .newBuilder()
+            .setRowCount(count)
+            .setData(ByteString.copyFrom(bytes))
+            .build()
+          response.setArrowBatch(batch)
+          responseObserver.onNext(response.build())
+          numSent += 1
+        }
+
+        currentPartitionId += 1
       }
     }
 
@@ -126,7 +191,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
 object SparkConnectStreamHandler {
   type Batch = (Array[Byte], Long)
 
-  private[service] def rowToArrowConverter(
+  private def rowToArrowConverter(
       schema: StructType,
       maxRecordsPerBatch: Int,
       maxBatchSize: Long,
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 133ce980ecdd..5f18b0d45c53 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -16,10 +16,12 @@
  */
 package org.apache.spark.sql.connect.planner
 
-import scala.concurrent.Promise
-import scala.concurrent.duration._
+import scala.collection.mutable
 
 import io.grpc.stub.StreamObserver
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{BigIntVector, Float8Vector}
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
@@ -28,7 +30,6 @@ import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.execution.ExplainMode
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.ThreadUtils
 
 /**
  * Testing Connect Service implementation.
@@ -67,6 +68,78 @@ class SparkConnectServiceSuite extends SharedSparkSession {
     }
   }
 
+  test("SPARK-41224: collect data using arrow") {
+    val instance = new SparkConnectService(false)
+    val connect = new MockRemoteSession()
+    val context = proto.UserContext
+      .newBuilder()
+      .setUserId("c1")
+      .build()
+    val plan = proto.Plan
+      .newBuilder()
+      .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
+      .build()
+    val request = proto.ExecutePlanRequest
+      .newBuilder()
+      .setPlan(plan)
+      .setUserContext(context)
+      .build()
+
+    // Execute plan.
+    @volatile var done = false
+    val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
+    instance.executePlan(
+      request,
+      new StreamObserver[proto.ExecutePlanResponse] {
+        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
+
+        override def onError(throwable: Throwable): Unit = throw throwable
+
+        override def onCompleted(): Unit = done = true
+      })
+
+    // The current implementation is expected to be blocking. This is here to make sure it is.
+    assert(done)
+
+    // 4 partitions + metrics.
+    assert(responses.size == 5)
+
+    // Make sure the last response contains only metrics.
+    val last = responses.last
+    assert(last.hasMetrics && !last.hasArrowBatch)
+
+    val allocator = new RootAllocator()
+
+    // Check the 'data' batches.
+    var expectedId = 0L
+    var previousEId = 0.0d
+    responses.dropRight(1).foreach { response =>
+      assert(response.hasArrowBatch)
+      val batch = response.getArrowBatch
+      assert(batch.getData != null)
+      assert(batch.getRowCount == 25)
+
+      val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
+      while (reader.loadNextBatch()) {
+        val root = reader.getVectorSchemaRoot
+        val idVector = root.getVector(0).asInstanceOf[BigIntVector]
+        val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
+        val numRows = root.getRowCount
+        var i = 0
+        while (i < numRows) {
+          assert(idVector.get(i) == expectedId)
+          expectedId += 1
+          val eid = eidVector.get(i)
+          assert(eid > previousEId)
+          previousEId = eid
+          i += 1
+        }
+      }
+      reader.close()
+    }
+    allocator.close()
+  }
+
   test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
     val instance = new SparkConnectService(false)
 
@@ -92,21 +165,23 @@ class SparkConnectServiceSuite extends SharedSparkSession {
       .setUserContext(context)
       .build()
 
-    val promise = Promise[Seq[proto.ExecutePlanResponse]]
+    // The observer is executed inside this thread, so we can
+    // perform the checks inside the observer itself.
     instance.executePlan(
       request,
       new StreamObserver[proto.ExecutePlanResponse] {
-        private val responses = Seq.newBuilder[proto.ExecutePlanResponse]
+        override def onNext(v: proto.ExecutePlanResponse): Unit = {
+          fail("this should not receive responses")
+        }
 
-        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
-
-        override def onError(throwable: Throwable): Unit = promise.failure(throwable)
+        override def onError(throwable: Throwable): Unit = {
+          assert(throwable.isInstanceOf[SparkException])
+        }
 
-        override def onCompleted(): Unit = promise.success(responses.result())
+        override def onCompleted(): Unit = {
+          fail("this should not complete")
+        }
       })
-
-    intercept[SparkException] {
-      ThreadUtils.awaitResult(promise.future, 2.seconds)
-    }
   }
 
   test("Test explain mode in analyze response") {
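
Note on the handler change: it is a small producer-consumer protocol between the DAGScheduler's result callback and the thread that owns the gRPC `responseObserver`. The sketch below is not part of the patch; it is a minimal standalone reproduction of the same pattern, under a few stated assumptions: a local master, `println` standing in for `responseObserver.onNext`, and an inline same-thread `ExecutionContext` replacing Spark's internal (`private[spark]`) `ThreadUtils.sameThread`.

```scala
import java.util.concurrent.Executor

import scala.concurrent.ExecutionContext

import org.apache.spark.sql.SparkSession

object OrderedCollectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("sketch").getOrCreate()
    val sc = spark.sparkContext

    val rdd = sc.parallelize(0 until 100, numSlices = 4)
    val numPartitions = rdd.getNumPartitions

    val signal = new Object
    val partitions = new Array[Array[Int]](numPartitions)
    var error: Option[Throwable] = None

    // Stand-in for ThreadUtils.sameThread, which is private[spark].
    val sameThread = ExecutionContext.fromExecutor(new Executor {
      override def execute(command: Runnable): Unit = command.run()
    })

    // Producer: the DAGScheduler thread stores each finished partition and
    // wakes up the consumer. Partitions may complete in any order.
    val future = sc.submitJob(
      rdd,
      (iter: Iterator[Int]) => iter.toArray,
      Seq.range(0, numPartitions),
      (partitionId: Int, data: Array[Int]) => {
        signal.synchronized {
          partitions(partitionId) = data
          signal.notify()
        }
      },
      ())

    // Propagate job failures to the consumer thread.
    future.onComplete { result =>
      result.failed.foreach { t =>
        signal.synchronized {
          error = Some(t)
          signal.notify()
        }
      }
    }(sameThread)

    // Consumer: emit partitions strictly in ascending partition id, blocking
    // until the next one is available, no matter which task finished first.
    var pid = 0
    while (pid < numPartitions) {
      val part = signal.synchronized {
        while (partitions(pid) == null && error.isEmpty) signal.wait()
        error.foreach(e => throw e)
        val p = partitions(pid)
        partitions(pid) = null // release the partition once it is consumed
        p
      }
      println(s"partition $pid: ${part.length} rows") // stand-in for responseObserver.onNext
      pid += 1
    }

    spark.stop()
  }
}
```

Tasks may finish out of order, but the consumer only ever advances to partition `pid + 1` after emitting partition `pid`, which is exactly the ordering guarantee a streaming gRPC response needs.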
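The new test consumes each `ExecutePlanResponse.ArrowBatch` by wrapping its bytes in an `ArrowStreamReader`. For context, here is a self-contained sketch (also not part of the patch) of the Arrow Java IPC round trip: it writes one record batch with a single hypothetical `id` BIGINT column of 25 rows, mirroring one partition of the test's query, then reads it back with the same reader loop the test uses. It assumes `arrow-vector` plus an Arrow memory implementation are on the classpath.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

object ArrowRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()

    // Write one record batch with a single nullable BIGINT column named "id".
    val field = new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null)
    val schema = new Schema(Seq(field).asJava)
    val root = VectorSchemaRoot.create(schema, allocator)
    val vector = root.getVector("id").asInstanceOf[BigIntVector]
    (0 until 25).foreach(i => vector.setSafe(i, i.toLong))
    root.setRowCount(25)

    val out = new ByteArrayOutputStream()
    val writer = new ArrowStreamWriter(root, null, out)
    writer.start()
    writer.writeBatch()
    writer.end()
    writer.close()
    root.close()

    // Read the bytes back, the same way the test consumes
    // ExecutePlanResponse.ArrowBatch.getData.
    val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
    while (reader.loadNextBatch()) {
      val readRoot = reader.getVectorSchemaRoot
      val ids = readRoot.getVector(0).asInstanceOf[BigIntVector]
      (0 until readRoot.getRowCount).foreach(i => assert(ids.get(i) == i.toLong))
    }
    reader.close()
    allocator.close()
  }
}
```

The `loadNextBatch()` / `getVectorSchemaRoot` loop is the same shape the test runs against `batch.getData.newInput()`; each iteration reloads the root with the next record batch in the stream.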