ExecuteThreadRunner.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.concurrent.{ExecutionContext, Promise}
 import scala.jdk.CollectionConverters._
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -30,18 +32,20 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * This class launches the actual execution in an execution thread. The execution pushes the
  * responses to an ExecuteResponseObserver in executeHolder.
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
 
+  private val promise: Promise[Unit] = Promise[Unit]()
+
   // The newly created thread will inherit all InheritableThreadLocals used by Spark,
   // e.g. SparkContext.localProperties. If considering implementing a thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private val executionThread: Thread = new ExecutionThread()
+  private val executionThread: ExecutionThread = new ExecutionThread(promise)
 
   private var started: Boolean = false
 
@@ -63,11 +67,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     }
   }
 
-  /** Joins the background execution thread after it is finished. */
-  private[connect] def join(): Unit = {
-    // only called when the execution is completed or interrupted.
-    assert(completed || interrupted)
-    executionThread.join()
+  /**
+   * Register a callback that gets executed after completion/interruption of the execution thread.
+   */
+  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = {
+    promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
   }
 
   /**
@@ -276,10 +280,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
       .build()
   }
 
-  private class ExecutionThread
+  private class ExecutionThread(onCompletionPromise: Promise[Unit])
       extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
-      execute()
+      try {
+        execute()
+        onCompletionPromise.success(())
+      } catch {
+        case NonFatal(e) =>
+          onCompletionPromise.failure(e)
+      }
     }
   }
 }
+
+private[connect] object ExecuteThreadRunner {
+  private implicit val namedExecutionContext: ExecutionContext = ExecutionContext
+    .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
+}
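
The core of this change is the promise-completed-by-the-worker pattern: the execution thread fulfills a Promise when run() finishes, and callbacks registered via onComplete fire on a dedicated daemon executor instead of a caller blocking in join(). A minimal self-contained sketch of the same pattern, assuming nothing from the patch beyond the standard library (PromiseCallbackDemo, Worker, and the thread names are illustrative):

import java.util.concurrent.{Executors, ThreadFactory}

import scala.concurrent.{ExecutionContext, Promise}
import scala.util.control.NonFatal
import scala.util.{Failure, Success}

object PromiseCallbackDemo {
  // Callbacks run on a single daemon thread, mirroring the patch's
  // "SparkConnectExecuteThreadCallback" executor.
  private implicit val callbackEc: ExecutionContext =
    ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor(new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, "demo-callback")
        t.setDaemon(true)
        t
      }
    }))

  // The worker completes the promise exactly once, on success or failure.
  private class Worker(promise: Promise[Unit]) extends Thread("demo-worker") {
    override def run(): Unit = {
      try {
        // ... the actual work would happen here ...
        promise.success(())
      } catch {
        case NonFatal(e) => promise.failure(e)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val promise = Promise[Unit]()
    // Callbacks can be registered before or after the promise completes.
    promise.future.onComplete {
      case Success(_) => println("execution finished, running cleanup")
      case Failure(e) => println(s"execution failed: $e")
    }
    new Worker(promise).start()
    Thread.sleep(200) // only the demo waits; the patch deliberately never joins
  }
}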
ExecuteHolder.scala
@@ -117,6 +117,9 @@ private[connect] class ExecuteHolder(
       : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] =
     new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]()
 
+  /** For testing only: whether the async completion callback has been called. */
+  @volatile private[connect] var completionCallbackCalled: Boolean = false
+
   /**
    * Start the execution. The execution is started in a background thread in ExecuteThreadRunner.
    * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the
@@ -238,8 +241,15 @@ private[connect] class ExecuteHolder(
     if (closedTimeMs.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
-      // wait for execution to finish, to make sure no more results get pushed to responseObserver
-      runner.join()
+      // Do not wait for the execution to finish; clean up resources immediately.
+      runner.processOnCompletion { _ =>
+        completionCallbackCalled = true
+        // The execution may not get interrupted immediately; clean up any remaining
+        // resources when it does.
+        responseObserver.removeAll()
+        // post closed to UI
+        eventsManager.postClosed()
+      }
       // interrupt any attached grpcResponseSenders
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
@@ -249,8 +259,6 @@ private[connect] class ExecuteHolder(
       }
       // remove all cached responses from observer
       responseObserver.removeAll()
-      // post closed to UI
-      eventsManager.postClosed()
       closedTimeMs = Some(System.currentTimeMillis())
     }
   }
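
With close() no longer joining the execution thread, removeAll() now runs twice: once synchronously in close(), and once from the completion callback after the execution thread actually stops producing. That only works if removal is thread-safe and harmless to repeat. A minimal sketch of that contract, assuming a simplified buffer (ResponseBuffer is an illustrative stand-in, not the real ExecuteResponseObserver):

import scala.collection.mutable

// Illustrative stand-in for a response observer whose removeAll() must be
// safe to call from both the closing thread and the async callback.
final class ResponseBuffer[T] {
  private val responses = mutable.ArrayBuffer.empty[T]

  def append(response: T): Unit = synchronized {
    responses += response
  }

  // Thread-safe and idempotent: a second call on an already-empty
  // buffer is simply a no-op.
  def removeAll(): Unit = synchronized {
    responses.clear()
  }
}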
SparkConnectServiceSuite.scala
@@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.mockito.Mockito.when
 import org.scalatest.Tag
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -884,8 +886,11 @@ class SparkConnectServiceSuite
       assert(executeHolder.eventsManager.hasError.isDefined)
     }
     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-      assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
       assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
+      // The eventsManager is closed asynchronously.
+      Eventually.eventually(timeout(1.seconds)) {
+        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      }
     }
     def onCanceled(): Unit = {
       assert(executeHolder.eventsManager.hasCanceled.contains(true))
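
Because eventsManager.postClosed() now happens on the callback thread, the test polls with ScalaTest's Eventually, which retries its block until the assertion passes or the timeout elapses. A standalone usage sketch of the same pattern (the suite name and flag are illustrative):

import java.util.concurrent.atomic.AtomicBoolean

import org.scalatest.concurrent.Eventually
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime

class AsyncFlagSuite extends AnyFunSuite with Eventually {
  test("flag set by a background thread is eventually observed") {
    val flag = new AtomicBoolean(false)
    new Thread(() => { Thread.sleep(50); flag.set(true) }).start()
    // Retries the block until it passes or 1 second elapses.
    eventually(timeout(1.seconds)) {
      assert(flag.get())
    }
  }
}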
SparkConnectServiceE2ESuite.scala
@@ -91,6 +91,28 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
     }
   }
 
+  test("Async cleanup callback gets called after the execution is closed") {
+    withClient(UUID.randomUUID().toString, defaultUserId) { client =>
+      val query1 = client.execute(buildPlan(BIG_ENOUGH_QUERY))
+      // Creating the iterator is lazy; this triggers query1 to be sent.
+      query1.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
+      }
+      val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head
+      // Close the session.
+      client.releaseSession()
+      // Check that the query gets canceled.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 0)
+      }
+      // Check that the async execution cleanup callback gets called.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(executeHolder1.completionCallbackCalled)
+      }
+    }
+  }
+
   private def testReleaseSessionTwoSessions(
       sessionIdA: String,
       userIdA: String,
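
The completionCallbackCalled flag is written by the callback thread and polled by the test thread, which is why it is declared @volatile: without it, the reader may never observe the write. A tiny illustration of the visibility guarantee, assuming nothing from the patch (names are illustrative):

object VolatileFlagDemo {
  // Written by the main thread, polled by the reader thread. Without
  // @volatile, the busy-wait below may spin forever on a stale value.
  @volatile private var done: Boolean = false

  def main(args: Array[String]): Unit = {
    val reader = new Thread(() => {
      while (!done) {} // busy-wait until the volatile write becomes visible
      println("observed the flag")
    })
    reader.start()
    Thread.sleep(50)
    done = true // volatile write: happens-before the reader's next read
    reader.join()
  }
}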