@@ -196,13 +196,14 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       signal.wait()
       result = partitions.remove(currentPartitionId)
     }
-    error match {
-      case NonFatal(e) =>
-        responseObserver.onError(error)
-        logError("Error while processing query.", e)
-        return
-      case fatal: Throwable => throw fatal
-      case null => result.get
+    if (error == null) {
+      result.get
+    } else if (NonFatal(error)) {
+      responseObserver.onError(error)
+      logError("Error while processing query.", error)
Reviewer comment (Contributor): In theory, logging first then responding might be better, in case something goes wrong and onError would throw. (A sketch of that ordering follows this hunk.)

+      return
+    } else {
+      throw error
     }
   }

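Following up on the reviewer comment above: a minimal, hypothetical sketch of the same branch with the log call moved before onError, so the failure is still recorded even if onError itself throws. The object and method names (ErrorOrderingSketch, finishWithError) and the logError parameter are invented for illustration and are not the PR's code; NonFatal and StreamObserver are the same types used in the hunk.

import scala.util.control.NonFatal

import io.grpc.stub.StreamObserver

object ErrorOrderingSketch {
  // Mirrors the hunk above, but logs before responding to the client.
  def finishWithError[T](
      error: Throwable,
      responseObserver: StreamObserver[T],
      logError: (String, Throwable) => Unit): Unit = {
    if (error != null) {
      if (NonFatal(error)) {
        logError("Error while processing query.", error) // log first ...
        responseObserver.onError(error) // ... then respond; a failure here no longer hides the log entry
      } else {
        throw error // fatal errors are rethrown unchanged, as in the PR
      }
    }
  }
}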
@@ -16,10 +16,12 @@
  */
 package org.apache.spark.sql.connect.planner
 
-import scala.concurrent.Promise
-import scala.concurrent.duration._
+import scala.collection.mutable
 
 import io.grpc.stub.StreamObserver
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{BigIntVector, Float8Vector}
+import org.apache.arrow.vector.ipc.ArrowStreamReader
 
 import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
@@ -28,7 +30,6 @@ import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.execution.ExplainMode
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.util.ThreadUtils

/**
* Testing Connect Service implementation.
@@ -67,6 +68,78 @@ class SparkConnectServiceSuite extends SharedSparkSession {
    }
  }

  test("collect data using arrow") {
    val instance = new SparkConnectService(false)
    val connect = new MockRemoteSession()
    val context = proto.UserContext
      .newBuilder()
      .setUserId("c1")
      .build()
    val plan = proto.Plan
      .newBuilder()
      .setRoot(connect.sql("select id, exp(id) as eid from range(0, 100, 1, 4)"))
      .build()
    val request = proto.ExecutePlanRequest
      .newBuilder()
      .setPlan(plan)
      .setUserContext(context)
      .build()

    // Execute plan.
    @volatile var done = false
    val responses = mutable.Buffer.empty[proto.ExecutePlanResponse]
    instance.executePlan(
      request,
      new StreamObserver[proto.ExecutePlanResponse] {
        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v

        override def onError(throwable: Throwable): Unit = throw throwable

        override def onCompleted(): Unit = done = true
      })

    // The current implementation is expected to be blocking. This is here to make sure it is.
    assert(done)

    // 4 Partitions + Metrics
    assert(responses.size == 5)

    // Make sure the last response is metrics only
    val last = responses.last
    assert(last.hasMetrics && !(last.hasArrowBatch || last.hasJsonBatch))

    val allocator = new RootAllocator()

    // Check the 'data' batches
    var expectedId = 0L
    var previousEId = 0.0d
    responses.dropRight(1).foreach { response =>
      assert(response.hasArrowBatch)
      val batch = response.getArrowBatch
      assert(batch.getData != null)
      assert(batch.getRowCount == 25)

      val reader = new ArrowStreamReader(batch.getData.newInput(), allocator)
      while (reader.loadNextBatch()) {
        val root = reader.getVectorSchemaRoot
        val idVector = root.getVector(0).asInstanceOf[BigIntVector]
        val eidVector = root.getVector(1).asInstanceOf[Float8Vector]
        val numRows = root.getRowCount
        var i = 0
        while (i < numRows) {
          assert(idVector.get(i) == expectedId)
          expectedId += 1
          val eid = eidVector.get(i)
          assert(eid > previousEId)
          previousEId = eid
          i += 1
        }
      }
      reader.close()
    }
    allocator.close()
  }

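The test above (and the SPARK-41165 test that follows) leans on the comment that the current executePlan implementation is blocking: the observer callbacks run on the calling thread, so assert(done) right after the call, and assertions placed directly inside the observer, are sufficient. A minimal, hypothetical sketch of that property, where executeBlocking and all other names are invented and are not the service's API:

import io.grpc.stub.StreamObserver

object BlockingObserverSketch {
  // Stand-in for a blocking call: it invokes the observer callbacks on the
  // calling thread before returning.
  def executeBlocking[T](result: Either[Throwable, T], observer: StreamObserver[T]): Unit =
    result match {
      case Left(e) => observer.onError(e) // any exception thrown inside surfaces to the caller
      case Right(v) =>
        observer.onNext(v)
        observer.onCompleted()
    }

  def main(args: Array[String]): Unit = {
    var done = false
    executeBlocking[String](Right("ok"), new StreamObserver[String] {
      override def onNext(v: String): Unit = ()
      override def onError(t: Throwable): Unit = throw t
      override def onCompleted(): Unit = done = true
    })
    // Because the call is synchronous, `done` is already true here.
    assert(done)
  }
}

If executePlan ever became asynchronous, these tests would again need explicit synchronization, which is what the removed Promise/ThreadUtils.awaitResult code provided.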
test("SPARK-41165: failures in the arrow collect path should not cause hangs") {
val instance = new SparkConnectService(false)

@@ -92,21 +165,23 @@ class SparkConnectServiceSuite extends SharedSparkSession {
     .setUserContext(context)
     .build()
 
-    val promise = Promise[Seq[proto.ExecutePlanResponse]]
+    // The observer is executed inside this thread. So
+    // we can perform the checks inside the observer.
     instance.executePlan(
       request,
       new StreamObserver[proto.ExecutePlanResponse] {
-        private val responses = Seq.newBuilder[proto.ExecutePlanResponse]
-
-        override def onNext(v: proto.ExecutePlanResponse): Unit = responses += v
+        override def onNext(v: proto.ExecutePlanResponse): Unit = {
+          fail("this should not receive responses")
+        }
 
-        override def onError(throwable: Throwable): Unit = promise.failure(throwable)
+        override def onError(throwable: Throwable): Unit = {
+          assert(throwable.isInstanceOf[SparkException])
+        }
 
-        override def onCompleted(): Unit = promise.success(responses.result())
+        override def onCompleted(): Unit = {
+          fail("this should not complete")
+        }
       })
-    intercept[SparkException] {
-      ThreadUtils.awaitResult(promise.future, 2.seconds)
-    }
   }
 
   test("Test explain mode in analyze response") {
python/pyspark/sql/tests/connect/test_connect_basic.py (1 change: 0 additions & 1 deletion)
@@ -220,7 +220,6 @@ def test_create_global_temp_view(self):
         with self.assertRaises(_MultiThreadedRendezvous):
             self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
-    @unittest.skip("test_fill_na is flaky")
     def test_fill_na(self):
         # SPARK-41128: Test fill na
         query = """