SparkConnectStreamHandler.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.service

import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import com.google.protobuf.ByteString
import io.grpc.stub.StreamObserver
@@ -32,6 +33,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ThreadUtils

class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
extends Logging {
@@ -71,20 +73,82 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
var numSent = 0

if (numPartitions > 0) {
type Batch = (Array[Byte], Long)

val batches = rows.mapPartitionsInternal(
SparkConnectStreamHandler
.rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))

batches.collect().foreach { case (bytes, count) =>
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
val signal = new Object
val partitions = new Array[Array[Batch]](numPartitions)
var error: Option[Throwable] = None

// This callback is executed on the DAGScheduler thread.
// After a partition is fetched, it stores the partition in the array and then
// wakes up the main thread.
val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
signal.synchronized {
partitions(partitionId) = partition
signal.notify()
}
()
}

val future = spark.sparkContext.submitJob(
rdd = batches,
processPartition = (iter: Iterator[Batch]) => iter.toArray,
partitions = Seq.range(0, numPartitions),
resultHandler = resultHandler,
resultFunc = () => ())

// Collect errors and propagate them to the main thread.
future.onComplete { result =>
result.failed.foreach { throwable =>
signal.synchronized {
error = Some(throwable)
signal.notify()
}
}
}(ThreadUtils.sameThread)

// The main thread waits until the 0-th partition is available,
// then sends it to the client and waits for the next partition.
// Unlike the implementation of [[Dataset#collectAsArrowToPython]], the Arrow batches are
// sent from the main thread, so the DAGScheduler thread is not blocked by work that is
// unrelated to scheduling. This is particularly important when multiple users or clients
// run code at the same time.
var currentPartitionId = 0
while (currentPartitionId < numPartitions) {
val partition = signal.synchronized {
var result = partitions(currentPartitionId)
while (result == null && error.isEmpty) {
signal.wait()
result = partitions(currentPartitionId)
}
partitions(currentPartitionId) = null

error match {
case None => result
case Some(NonFatal(e)) =>
responseObserver.onError(e)
logError("Error while processing query.", e)
return
case Some(other) => throw other
}
}

partition.foreach { case (bytes, count) =>
val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
val batch = proto.ExecutePlanResponse.ArrowBatch
.newBuilder()
.setRowCount(count)
.setData(ByteString.copyFrom(bytes))
.build()
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
}

currentPartitionId += 1
}
}
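
For illustration only, not part of this diff: a minimal, self-contained sketch of the same pattern, submitting a job whose `resultHandler` stashes each partition as it arrives on the scheduler thread while the calling thread emits partitions strictly in order. The `OrderedCollectSketch` name, the local master, `println` standing in for `responseObserver.onNext`, and `ExecutionContext.global` standing in for Spark's internal `ThreadUtils.sameThread` are all assumptions of the sketch.

```scala
import scala.concurrent.ExecutionContext
import scala.util.control.NonFatal

import org.apache.spark.sql.SparkSession

object OrderedCollectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[4]").appName("ordered-collect-sketch").getOrCreate()
    val sc = spark.sparkContext

    val rdd = sc.parallelize(0 until 100, numSlices = 4)
    val numPartitions = rdd.getNumPartitions

    val signal = new Object
    val partitions = new Array[Array[Int]](numPartitions)
    var error: Option[Throwable] = None

    // Runs on the scheduler's thread as each partition finishes, possibly out of order;
    // it only stashes the result and wakes the main thread.
    val resultHandler = (partitionId: Int, partition: Array[Int]) => {
      signal.synchronized {
        partitions(partitionId) = partition
        signal.notify()
      }
    }

    val future = sc.submitJob[Int, Array[Int], Unit](
      rdd,
      (iter: Iterator[Int]) => iter.toArray,
      Seq.range(0, numPartitions),
      resultHandler,
      ())

    // Propagate failures to the main thread instead of letting them vanish.
    future.onComplete { result =>
      result.failed.foreach { throwable =>
        signal.synchronized {
          error = Some(throwable)
          signal.notify()
        }
      }
    }(ExecutionContext.global)

    // The main thread emits partitions strictly in order, waiting for each one to arrive.
    var currentPartitionId = 0
    while (currentPartitionId < numPartitions) {
      val partition = signal.synchronized {
        var result = partitions(currentPartitionId)
        while (result == null && error.isEmpty) {
          signal.wait()
          result = partitions(currentPartitionId)
        }
        error.foreach {
          case NonFatal(e) => throw new RuntimeException("job failed", e)
          case fatal => throw fatal
        }
        result
      }
      println(s"partition $currentPartitionId: ${partition.length} values")
      currentPartitionId += 1
    }

    spark.stop()
  }
}
```

The handler does no real work on the scheduler thread; all sending happens on the caller's thread, which is the point of the change above.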

@@ -126,7 +190,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
object SparkConnectStreamHandler {
type Batch = (Array[Byte], Long)

private[service] def rowToArrowConverter(
private def rowToArrowConverter(
schema: StructType,
maxRecordsPerBatch: Int,
maxBatchSize: Long,
SparkConnectServiceSuite.scala
@@ -16,10 +16,12 @@
*/
package org.apache.spark.sql.connect.planner

import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.collection.mutable

import io.grpc.stub.StreamObserver
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, Float8Vector}
import org.apache.arrow.vector.ipc.ArrowStreamReader

import org.apache.spark.SparkException
import org.apache.spark.connect.proto
@@ -28,7 +30,6 @@ import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.service.SparkConnectService
import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.ThreadUtils

/**
* Testing Connect Service implementation.
@@ -67,6 +68,78 @@ class SparkConnectServiceSuite extends SharedSparkSession {
}
}

test("SPARK-41224: collect data using arrow") {
val instance = new SparkConnectService(false)
val connect = new MockRemoteSession()
val context = proto.UserContext
.newBuilder()
.setUserId("c1")
.build()
val plan = proto.Plan
.newBuilder()
.setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
.build()
val request = proto.ExecutePlanRequest
.newBuilder()
.setPlan(plan)
.setUserContext(context)
.build()

// Execute plan.
@volatile var done = false
val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
instance.executePlan(
request,
new StreamObserver[proto.ExecutePlanResponse] {
override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v

override def onError(throwable: Throwable): Unit = throw throwable

override def onCompleted(): Unit = done = true
})

// The current implementation is expected to be blocking. This is here to make sure it is.
assert(done)

// 4 Partitions + Metrics
assert(responses.size == 5)

// Make sure the last response is metrics only
val last = responses.last
assert(last.hasMetrics && !(last.hasArrowBatch || last.hasJsonBatch))

val allocator = new RootAllocator()

// Check the 'data' batches
var expectedId = 0L
var previousEId = 0.0d
responses.dropRight(1).foreach { response =>
assert(response.hasArrowBatch)
val batch = response.getArrowBatch
assert(batch.getData != null)
assert(batch.getRowCount == 25)

val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
while (reader.loadNextBatch()) {
val root = reader.getVectorSchemaRoot
val idVector = root.getVector(0).asInstanceOf[BigIntVector]
val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
val numRows = root.getRowCount
var i = 0
while (i < numRows) {
assert(idVector.get(i) == expectedId)
expectedId += 1
val eid = eidVector.get(i)
assert(eid > previousEId)
previousEId = eid
i += 1
}
}
reader.close()
}
allocator.close()
}
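
As a side note on the decoding done in the test above: each `ArrowBatch.getData` payload is an Arrow IPC stream, readable with `ArrowStreamReader` exactly as shown. Below is a minimal round-trip sketch, independent of Spark, showing how such a stream is written and read back; the `ArrowIpcRoundTrip` object and the single `id` column are assumptions of the sketch.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

object ArrowIpcRoundTrip {
  def main(args: Array[String]): Unit = {
    val allocator = new RootAllocator()
    val idField = new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null)
    val schema = new Schema(Seq(idField).asJava)

    // Write one record batch into an in-memory Arrow IPC stream.
    val out = new ByteArrayOutputStream()
    val root = VectorSchemaRoot.create(schema, allocator)
    val writer = new ArrowStreamWriter(root, null, out)
    writer.start()
    val ids = root.getVector("id").asInstanceOf[BigIntVector]
    ids.allocateNew(3)
    (0 until 3).foreach(i => ids.setSafe(i, i.toLong))
    root.setRowCount(3)
    writer.writeBatch()
    writer.end()
    writer.close()
    root.close()

    // Read the stream back, the same way the suite decodes batch.getData.
    val reader = new ArrowStreamReader(new ByteArrayInputStream(out.toByteArray), allocator)
    while (reader.loadNextBatch()) {
      val readRoot = reader.getVectorSchemaRoot
      val readIds = readRoot.getVector("id").asInstanceOf[BigIntVector]
      (0 until readRoot.getRowCount).foreach(i => println(readIds.get(i)))
    }
    reader.close()
    allocator.close()
  }
}
```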

test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
val instance = new SparkConnectService(false)

@@ -92,21 +165,23 @@ class SparkConnectServiceSuite extends SharedSparkSession {
.setUserContext(context)
.build()

val promise = Promise[Seq[proto.ExecutePlanResponse]]
// The observer is executed on the calling thread, so
// we can perform the checks inside the observer.
instance.executePlan(
request,
new StreamObserver[proto.ExecutePlanResponse] {
private val responses = Seq.newBuilder[proto.ExecutePlanResponse]
override def onNext(v: proto.ExecutePlanResponse): Unit = {
fail("this should not receive responses")
}

override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v

override def onError(throwable: Throwable): Unit = promise.failure(throwable)
override def onError(throwable: Throwable): Unit = {
assert(throwable.isInstanceOf[SparkException])
}

override def onCompleted(): Unit = promise.success(responses.result())
override def onCompleted(): Unit = {
fail("this should not complete")
}
})
intercept[SparkException] {
ThreadUtils.awaitResult(promise.future, 2.seconds)
}
}
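
Since `executePlan` is blocking here, a small reusable observer could make such assertions less ad hoc. The `CollectingObserver` name and shape below are an illustrative suggestion, not part of this change.

```scala
import scala.collection.mutable

import io.grpc.stub.StreamObserver

// Records everything pushed into it; assertions run after the blocking call returns.
class CollectingObserver[T] extends StreamObserver[T] {
  val values: mutable.Buffer[T] = mutable.Buffer.empty[T]
  @volatile var error: Option[Throwable] = None
  @volatile var completed: Boolean = false

  override def onNext(value: T): Unit = values += value
  override def onError(throwable: Throwable): Unit = error = Some(throwable)
  override def onCompleted(): Unit = completed = true
}
```

Usage would look like `val observer = new CollectingObserver[proto.ExecutePlanResponse]()`, then `instance.executePlan(request, observer)` followed by assertions on `observer.error` or `observer.values`.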

test("Test explain mode in analyze response") {