diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 3b734616b213..c535cf192d01 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -160,17 +160,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
         ()
       }
 
-      spark.sparkContext.runJob(batches, processPartition, resultHandler)
+      spark.sparkContext.submitJob(
+        rdd = batches,
+        processPartition = processPartition,
+        partitions = Seq.range(0, numPartitions),
+        resultHandler = resultHandler,
+        resultFunc = () => ())
 
-      // The man thread will wait until 0-th partition is available,
-      // then send it to client and wait for next partition.
+      // The main thread will wait until 0-th partition is available,
+      // then send it to client and wait for the next partition.
       var currentPartitionId = 0
       while (currentPartitionId < numPartitions) {
         val partition = signal.synchronized {
-          while (!partitions.contains(currentPartitionId)) {
+          var result = partitions.remove(currentPartitionId)
+          while (result.isEmpty) {
             signal.wait()
+            result = partitions.remove(currentPartitionId)
           }
-          partitions.remove(currentPartitionId).get
+          result.get
         }
 
         partition.foreach { case (bytes, count) =>