```diff
@@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable {
   def onQueryTerminated(event: QueryTerminatedEvent): Unit
 }
 
-/**
- * Py4J allows a pure interface so this proxy is required.
- */
-private[spark] trait PythonStreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit
-
-  def onQueryProgress(event: QueryProgressEvent): Unit
-
-  def onQueryIdle(event: QueryIdleEvent): Unit
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit
-}
-
-private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener)
-  extends StreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)
-
-  def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)
-
-  override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
-}
-
 /**
  * Companion object of [[StreamingQueryListener]] that defines the listener events.
  * @since 3.5.0
```
```diff
@@ -3097,10 +3097,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
         val listenerId = command.getRemoveListener.getId
-        val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId)
-        session.streams.removeListener(listener)
-        sessionHolder.removeCachedListener(listenerId)
-        respBuilder.setRemoveListener(true)
+        sessionHolder.getListener(listenerId) match {
+          case Some(listener) =>
+            session.streams.removeListener(listener)
+            sessionHolder.removeCachedListener(listenerId)
+            respBuilder.setRemoveListener(true)
+          case None =>
+            respBuilder.setRemoveListener(false)
+        }
 
       case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS =>
         respBuilder.getListListenersBuilder
```
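The change above makes listener removal idempotent: an unknown ID now answers `false` instead of surfacing an `InvalidPlanInput` to the client. A minimal sketch of the same pattern, with a hypothetical `cache` standing in for the session's listener cache (names here are illustrative, not the actual Spark code):

```scala
import java.util.concurrent.ConcurrentHashMap

object RemoveListenerSketch {
  // Hypothetical stand-in for SessionHolder's listenerCache.
  private val cache = new ConcurrentHashMap[String, AnyRef]()

  // Mirrors the new planner flow: a missing ID yields None rather than an
  // exception, so removing the same listener twice simply reports false.
  def remove(id: String): Boolean =
    Option(cache.get(id)) match {
      case Some(_) =>
        cache.remove(id)
        true  // -> respBuilder.setRemoveListener(true)
      case None =>
        false // -> respBuilder.setRemoveListener(false), no exception
    }

  def main(args: Array[String]): Unit = {
    cache.put("listener-1", new Object)
    println(remove("listener-1")) // true
    println(remove("listener-1")) // false, and no error
  }
}
```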
```diff
@@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging {
 
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-    val runner = StreamingPythonRunner(pythonFn, connectUrl)
+    val runner = StreamingPythonRunner(
+      pythonFn,
+      connectUrl,
+      sessionHolder.sessionId,
+      "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
     val (dataOut, dataIn) =
-      runner.init(
-        sessionHolder.sessionId,
-        "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
+      runner.init()
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
```
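The session ID and worker module move from `init()` into the constructor because the runner's new `stop()` (shown further down) needs the same `workerModule` and environment to destroy the worker that `init()` created. An illustrative sketch of that pattern only, not the actual Spark classes:

```scala
import java.net.Socket

// Illustrative pattern: parameters needed by both init() and stop() are
// captured at construction, so stop() can tear down exactly what init()
// created without being handed the arguments a second time.
class WorkerHandle(workerModule: String) {
  private var worker: Option[Socket] = None

  def init(): Unit =
    worker = Some(create(workerModule)) // stands in for env.createPythonWorker

  def stop(): Unit =
    worker.foreach(w => destroy(workerModule, w)) // stands in for destroyPythonWorker

  // Hypothetical helpers; the real ones live on SparkEnv.
  private def create(module: String): Socket = new Socket()
  private def destroy(module: String, w: Socket): Unit = w.close()
}
```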
```diff
@@ -24,20 +24,23 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 /**
  * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each
  * instance of this class starts a python process, inside which has the python handling logic.
- * When new a event is received, it is serialized to json, and passed to the python process.
+ * When a new event is received, it is serialized to json, and passed to the python process.
  */
 class PythonStreamingQueryListener(
     listener: SimplePythonFunction,
     sessionHolder: SessionHolder,
     pythonExec: String)
   extends StreamingQueryListener {
 
-  val port = SparkConnectService.localPort
-  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-  val runner = StreamingPythonRunner(listener, connectUrl)
+  private val port = SparkConnectService.localPort
+  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private val runner = StreamingPythonRunner(
+    listener,
+    connectUrl,
+    sessionHolder.sessionId,
+    "pyspark.sql.connect.streaming.worker.listener_worker")
 
-  val (dataOut, _) =
-    runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker")
+  val (dataOut, _) = runner.init()
 
   override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
@@ -63,7 +66,7 @@ class PythonStreamingQueryListener(
     dataOut.flush()
   }
 
-  // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes.
+  // Similar to foreachBatch when we need to exit the process when the query ends.
+  // In listener semantics, we need to exit the process when removeListener is called.
   private[spark] def stopListenerProcess(): Unit = {
     runner.stop()
   }
 }
```
```diff
@@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock}
 import org.apache.spark.util.Utils
@@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
   }
 
   /**
-   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw
-   * [[InvalidPlanInput]].
+   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return
+   * None.
    */
-  private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = {
+  private[connect] def getListener(id: String): Option[StreamingQueryListener] = {
     Option(listenerCache.get(id))
-      .getOrElse {
-        throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId")
-      }
   }
 
   /**
-   * Removes corresponding StreamingQueryListener by ID.
+   * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a
+   * Spark Connect PythonStreamingQueryListener.
    */
-  private[connect] def removeCachedListener(id: String): StreamingQueryListener = {
+  private[connect] def removeCachedListener(id: String): Unit = {
+    listenerCache.get(id) match {
+      case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess()
+      case _ => // do nothing
+    }
     listenerCache.remove(id)
   }
```
```diff
@@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTH
 
 
 private[spark] object StreamingPythonRunner {
-  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
-    new StreamingPythonRunner(func, connectUrl)
+  def apply(
+      func: PythonFunction,
+      connectUrl: String,
+      sessionId: String,
+      workerModule: String
+  ): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl, sessionId, workerModule)
   }
 }
 
-private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
-  extends Logging {
+private[spark] class StreamingPythonRunner(
+    func: PythonFunction,
+    connectUrl: String,
+    sessionId: String,
+    workerModule: String) extends Logging {
   private val conf = SparkEnv.get.conf
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
 
   private val envVars: java.util.Map[String, String] = func.envVars
   private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[Socket] = None
   protected val pythonVer: String = func.pythonVer
 
   /**
    * Initializes the Python worker for streaming functions. Sets up Spark Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
-    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+  def init(): (DataOutputStream, DataInputStream) = {
+    logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec")
     val env = SparkEnv.get
 
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -60,9 +69,9 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
     conf.set(PYTHON_USE_DAEMON, false)
```
> **@ueshin** (Member) commented on Aug 3, 2023, on `conf.set(PYTHON_USE_DAEMON, false)`:
>
> This is not updated in this PR, but should we set this back to the original value after creating the Python worker? As the conf is visible from other parts of the Driver, it could affect the behavior.
>
> It can be done in a separate PR.

> A follow-up comment:
>
> Btw, do we need to change this at all? It might be simpler to keep this unchanged.
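A sketch of what the first reviewer proposes, for concreteness (illustrative only; per the thread, any such change was deferred to a separate PR):

```scala
// Reviewer's suggestion, sketched: snapshot the driver-wide conf value,
// flip it for worker creation, and restore it afterwards so other code on
// the driver is not affected. Not part of this PR.
val previousUseDaemon = conf.get(PYTHON_USE_DAEMON)
conf.set(PYTHON_USE_DAEMON, false)
try {
  // ... create the Python worker here ...
} finally {
  conf.set(PYTHON_USE_DAEMON, previousUseDaemon)
}
```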

```diff
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val pythonWorkerFactory =
-      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker()
+    val (worker, _) = env.createPythonWorker(
+      pythonExec, workerModule, envVars.asScala.toMap)
+    pythonWorker = Some(worker)
 
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
@@ -85,4 +94,13 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
 
     (dataOut, dataIn)
   }
+
+  /**
+   * Stops the Python worker.
+   */
+  def stop(): Unit = {
+    pythonWorker.foreach { worker =>
+      SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+    }
+  }
 }
```

> A reviewer (Contributor) commented on `def stop(): Unit`:
>
> Minor: please add documentation since this is a public function.
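Taken together, the runner now has a symmetric lifecycle: construct it with everything `stop()` will later need, call `init()` to spawn the worker, call `stop()` to destroy it. A sketch of the call sequence (server-internal API, shown for orientation only; `func`, `connectUrl`, and `sessionId` are assumed to be in scope):

```scala
val runner = StreamingPythonRunner(
  func,       // serialized Python function carrying env vars, exec path, version
  connectUrl, // e.g. s"sc://localhost:$port/;user_id=$userId"
  sessionId,
  "pyspark.sql.connect.streaming.worker.listener_worker")

val (dataOut, dataIn) = runner.init() // spawns the worker, wires up the session
// ... write events/batches to dataOut, read acknowledgements from dataIn ...
runner.stop() // destroys the worker via SparkEnv.destroyPythonWorker
```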
```diff
@@ -76,7 +76,9 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each micro batch.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
```
```diff
@@ -89,7 +89,9 @@ def process(listener_event_str, listener_event_type):  # type: ignore[no-untyped
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each listener event.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
```
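Both workers disable the read timeout because their sockets can legitimately sit idle for long stretches between events or micro-batches. For comparison, the JVM-side equivalent of Python's `sock.settimeout(None)` (a sketch; on `java.net.Socket`, a timeout of `0` means block indefinitely):

```scala
import java.net.Socket

// JVM analogue of sock.settimeout(None): setSoTimeout(0) makes reads block
// indefinitely, so a long gap between listener events or micro-batches
// cannot raise a read-timeout error.
def disableReadTimeout(sock: Socket): Unit = sock.setSoTimeout(0)
```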
```diff
@@ -60,6 +60,10 @@ def test_listener_events(self):
         try:
             self.spark.streams.addListener(test_listener)
 
+            # This ensures the read socket on the server won't crash (i.e. because of timeout)
+            # when there hasn't been a new event for a long time
+            time.sleep(30)
+
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
             q = df.writeStream.format("noop").queryName("test").start()
 
@@ -76,6 +80,9 @@ def test_listener_events(self):
         finally:
             self.spark.streams.removeListener(test_listener)
 
+            # Remove again to verify this won't throw any error
+            self.spark.streams.removeListener(test_listener)
+
 
 if __name__ == "__main__":
     import unittest
```
> A reviewer asked about the repeated `removeListener` call:
>
> How does this test ensure the listener worker is removed? Another PR #42385 broke the `stop()` method, but it didn't cause any failure. (I added a comment about the breaking change: https://github.com/apache/spark/pull/42385/files#r1295429496)

> The PR author (Contributor) replied:
>
> That test is not meant to ensure the worker is removed; it ensures that no error is thrown when `removeListener` is called twice on the same listener.