diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 56776819dac9..37c3120a8ff4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.concurrent.{ExecutionContext, Promise}
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -30,7 +32,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * This class launches the actual execution in an execution thread. The execution pushes the
@@ -38,10 +40,12 @@ import org.apache.spark.util.Utils
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
 
+  private val promise: Promise[Unit] = Promise[Unit]()
+
   // The newly created thread will inherit all InheritableThreadLocals used by Spark,
   // e.g. SparkContext.localProperties. If considering implementing a thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: Thread = new ExecutionThread()
+  private val executionThread: ExecutionThread = new ExecutionThread(promise)
 
   private var started: Boolean = false
 
@@ -63,11 +67,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     }
   }
 
-  /** Joins the background execution thread after it is finished. */
-  private[connect] def join(): Unit = {
-    // only called when the execution is completed or interrupted.
-    assert(completed || interrupted)
-    executionThread.join()
+  /**
+   * Register a callback that gets executed after completion/interruption of the execution thread.
+   */
+  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = {
+    promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
   }
 
   /**
@@ -276,10 +280,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
       .build()
   }
 
-  private class ExecutionThread
+  private class ExecutionThread(onCompletionPromise: Promise[Unit])
       extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
-      execute()
+      try {
+        execute()
+        onCompletionPromise.success(())
+      } catch {
+        case NonFatal(e) =>
+          onCompletionPromise.failure(e)
+      }
     }
   }
 }
+
+private[connect] object ExecuteThreadRunner {
+  private implicit val namedExecutionContext: ExecutionContext = ExecutionContext
+    .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index f03f81326064..3112d12bb0e6 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -117,6 +117,9 @@ private[connect] class ExecuteHolder(
       : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] =
     new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]()
 
+  /** For testing. Whether the async completion callback is called. */
+  @volatile private[connect] var completionCallbackCalled: Boolean = false
+
   /**
    * Start the execution. The execution is started in a background thread in ExecuteThreadRunner.
    * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the
@@ -238,8 +241,15 @@ private[connect] class ExecuteHolder(
     if (closedTimeMs.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
-      // wait for execution to finish, to make sure no more results get pushed to responseObserver
-      runner.join()
+      // Do not wait for the execution to finish; clean up resources immediately.
+      runner.processOnCompletion { _ =>
+        completionCallbackCalled = true
+        // The execution may not get interrupted immediately; clean up any remaining
+        // resources when it does.
+        responseObserver.removeAll()
+        // post closed to UI
+        eventsManager.postClosed()
+      }
       // interrupt any attached grpcResponseSenders
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
@@ -249,8 +259,6 @@ private[connect] class ExecuteHolder(
       }
       // remove all cached responses from observer
       responseObserver.removeAll()
-      // post closed to UI
-      eventsManager.postClosed()
       closedTimeMs = Some(System.currentTimeMillis())
     }
   }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 63cebd452364..af18fca9dd21 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.mockito.Mockito.when
 import org.scalatest.Tag
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -884,8 +886,11 @@ class SparkConnectServiceSuite
       assert(executeHolder.eventsManager.hasError.isDefined)
     }
     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-      assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
       assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
+      // The eventsManager is closed asynchronously
+      Eventually.eventually(timeout(1.seconds)) {
+        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      }
     }
     def onCanceled(): Unit = {
       assert(executeHolder.eventsManager.hasCanceled.contains(true))
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 33560cd53f6b..cb0bd8f771eb 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -91,6 +91,28 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
     }
   }
 
+  test("Async cleanup callback gets called after the execution is closed") {
+    withClient(UUID.randomUUID().toString, defaultUserId) { client =>
+      val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      // Just creating the iterator is lazy; calling hasNext triggers query1 to be sent.
+      query1.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
+      }
+      val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head
+      // Close the session.
+      client.releaseSession()
+      // Check that the query gets cancelled.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 0)
+      }
+      // Check that the async execute cleanup callback gets called.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(executeHolder1.completionCallbackCalled)
+      }
+    }
+  }
+
   private def testReleaseSessionTwoSessions(
       sessionIdA: String,
       userIdA: String,
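Reviewer note, not part of the patch above: the change replaces the blocking runner.join() in ExecuteHolder.close() with a non-blocking completion callback. The execution thread completes a Promise when it finishes, and callbacks registered through the resulting Future run on a dedicated single-thread daemon executor. The following standalone Scala sketch illustrates that pattern under stated assumptions: CompletionCallbackSketch, Worker, and the thread names are illustrative and are not Spark Connect APIs; only processOnCompletion mirrors the method name added by the patch, and java.util.concurrent.Executors stands in for ThreadUtils.newDaemonSingleThreadExecutor.

import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Promise}
import scala.util.{Failure, Success, Try}
import scala.util.control.NonFatal

object CompletionCallbackSketch {

  // A single daemon thread for callbacks, similar in spirit to the
  // "SparkConnectExecuteThreadCallback" executor added by the patch.
  private val callbackExecutionContext: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor { runnable =>
      val t = new Thread(runnable, "completion-callback")
      t.setDaemon(true)
      t
    })

  // Runs `body` on a background thread and completes the promise exactly once,
  // mirroring how ExecutionThread.run() completes the runner's promise.
  final class Worker(body: () => Unit) {
    private val promise = Promise[Unit]()

    private val thread = new Thread(
      () =>
        try {
          body()
          promise.success(())
        } catch {
          case NonFatal(e) => promise.failure(e)
        },
      "worker")

    def start(): Unit = thread.start()

    // Registers cleanup that runs after the worker finishes, without blocking the caller.
    def processOnCompletion(callback: Try[Unit] => Unit): Unit =
      promise.future.onComplete(callback)(callbackExecutionContext)
  }

  def main(args: Array[String]): Unit = {
    val worker = new Worker(() => Thread.sleep(100))
    // Registration order does not matter; the callback fires once the promise completes.
    worker.processOnCompletion {
      case Success(_) => println("worker finished, cleaning up")
      case Failure(e) => println(s"worker failed with $e, cleaning up anyway")
    }
    worker.start()
    // Give the daemon callback thread time to run before the JVM exits.
    Thread.sleep(500)
  }
}

Completing the promise from the worker thread and shifting the callback onto its own executor is what lets close() return immediately while the cleanup still runs exactly once, whenever the execution actually stops.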