diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index e2f3be02ad3a..404bd1b078ba 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -75,34 +75,6 @@ abstract class StreamingQueryListener extends Serializable {
   def onQueryTerminated(event: QueryTerminatedEvent): Unit
 }
 
-/**
- * Py4J allows a pure interface so this proxy is required.
- */
-private[spark] trait PythonStreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit
-
-  def onQueryProgress(event: QueryProgressEvent): Unit
-
-  def onQueryIdle(event: QueryIdleEvent): Unit
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit
-}
-
-private[spark] class PythonStreamingQueryListenerWrapper(listener: PythonStreamingQueryListener)
-    extends StreamingQueryListener {
-  import StreamingQueryListener._
-
-  def onQueryStarted(event: QueryStartedEvent): Unit = listener.onQueryStarted(event)
-
-  def onQueryProgress(event: QueryProgressEvent): Unit = listener.onQueryProgress(event)
-
-  override def onQueryIdle(event: QueryIdleEvent): Unit = listener.onQueryIdle(event)
-
-  def onQueryTerminated(event: QueryTerminatedEvent): Unit = listener.onQueryTerminated(event)
-}
-
 /**
  * Companion object of [[StreamingQueryListener]] that defines the listener events.
  * @since 3.5.0
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f4b33ae961a2..7136476b515f 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -3097,10 +3097,14 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
 
       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
         val listenerId = command.getRemoveListener.getId
-        val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId)
-        session.streams.removeListener(listener)
-        sessionHolder.removeCachedListener(listenerId)
-        respBuilder.setRemoveListener(true)
+        sessionHolder.getListener(listenerId) match {
+          case Some(listener) =>
+            session.streams.removeListener(listener)
+            sessionHolder.removeCachedListener(listenerId)
+            respBuilder.setRemoveListener(true)
+          case None =>
+            respBuilder.setRemoveListener(false)
+        }
 
       case StreamingQueryManagerCommand.CommandCase.LIST_LISTENERS =>
         respBuilder.getListListenersBuilder
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index 3b9ae483cf1b..f2fa4bf869b9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -87,11 +87,13 @@ object StreamingForeachBatchHelper extends Logging {
 
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-    val runner = StreamingPythonRunner(pythonFn, connectUrl)
+    val runner = StreamingPythonRunner(
+      pythonFn,
+      connectUrl,
+      sessionHolder.sessionId,
+      "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
     val (dataOut, dataIn) =
-      runner.init(
-        sessionHolder.sessionId,
-        "pyspark.sql.connect.streaming.worker.foreachBatch_worker")
+      runner.init()
 
     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index d915bc934960..9b2a931ec4ac 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
 /**
  * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each
  * instance of this class starts a python process, inside which has the python handling logic.
- * When new a event is received, it is serialized to json, and passed to the python process.
+ * When a new event is received, it is serialized to json, and passed to the python process.
  */
 class PythonStreamingQueryListener(
     listener: SimplePythonFunction,
@@ -32,12 +32,15 @@ class PythonStreamingQueryListener(
     pythonExec: String)
     extends StreamingQueryListener {
 
-  val port = SparkConnectService.localPort
-  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-  val runner = StreamingPythonRunner(listener, connectUrl)
+  private val port = SparkConnectService.localPort
+  private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
+  private val runner = StreamingPythonRunner(
+    listener,
+    connectUrl,
+    sessionHolder.sessionId,
+    "pyspark.sql.connect.streaming.worker.listener_worker")
 
-  val (dataOut, _) =
-    runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker")
+  val (dataOut, _) = runner.init()
 
   override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
     PythonRDD.writeUTF(event.json, dataOut)
@@ -63,7 +66,7 @@ class PythonStreamingQueryListener(
     dataOut.flush()
   }
 
-  // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes.
-  // Similar to foreachBatch when we need to exit the process when the query ends.
-  // In listener semantics, we need to exit the process when removeListener is called.
+  private[spark] def stopListenerProcess(): Unit = {
+    runner.stop()
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 310bb9208c21..29134f0dc0de 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock}
 import org.apache.spark.util.Utils
@@ -220,20 +221,22 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
   }
 
   /**
-   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, throw
-   * [[InvalidPlanInput]].
+   * Returns [[StreamingQueryListener]] cached for Listener ID `id`. If it is not found, return
+   * None.
    */
-  private[connect] def getListenerOrThrow(id: String): StreamingQueryListener = {
+  private[connect] def getListener(id: String): Option[StreamingQueryListener] = {
     Option(listenerCache.get(id))
-      .getOrElse {
-        throw InvalidPlanInput(s"No listener with id $id is found in the session $sessionId")
-      }
   }
 
   /**
-   * Removes corresponding StreamingQueryListener by ID.
+   * Removes corresponding StreamingQueryListener by ID. Terminates the python process if it's a
+   * Spark Connect PythonStreamingQueryListener.
    */
-  private[connect] def removeCachedListener(id: String): StreamingQueryListener = {
+  private[connect] def removeCachedListener(id: String): Unit = {
+    listenerCache.get(id) match {
+      case pyListener: PythonStreamingQueryListener => pyListener.stopListenerProcess()
+      case _ => // do nothing
+    }
     listenerCache.remove(id)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index d4fd9485675f..f14289f984a2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -29,27 +29,36 @@ import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, PYTHON_USE_DAEMON}
 
 
 private[spark] object StreamingPythonRunner {
-  def apply(func: PythonFunction, connectUrl: String): StreamingPythonRunner = {
-    new StreamingPythonRunner(func, connectUrl)
+  def apply(
+      func: PythonFunction,
+      connectUrl: String,
+      sessionId: String,
+      workerModule: String
+  ): StreamingPythonRunner = {
+    new StreamingPythonRunner(func, connectUrl, sessionId, workerModule)
   }
 }
 
-private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
-  extends Logging {
+private[spark] class StreamingPythonRunner(
+    func: PythonFunction,
+    connectUrl: String,
+    sessionId: String,
+    workerModule: String) extends Logging {
   private val conf = SparkEnv.get.conf
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
 
   private val envVars: java.util.Map[String, String] = func.envVars
   private val pythonExec: String = func.pythonExec
+  private var pythonWorker: Option[Socket] = None
   protected val pythonVer: String = func.pythonVer
 
   /**
   * Initializes the Python worker for streaming functions. Sets up Spark Connect session
   * to be used with the functions.
   */
-  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
-    logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
+  def init(): (DataOutputStream, DataInputStream) = {
+    logInfo(s"Initializing Python runner (session: $sessionId, pythonExec: $pythonExec")
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
 
@@ -60,9 +69,9 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
     conf.set(PYTHON_USE_DAEMON, false)
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val pythonWorkerFactory =
-      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
-    val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker()
+    val (worker, _) = env.createPythonWorker(
+      pythonExec, workerModule, envVars.asScala.toMap)
+    pythonWorker = Some(worker)
 
     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
@@ -85,4 +94,13 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: String)
 
     (dataOut, dataIn)
   }
+
+  /**
+   * Stops the Python worker.
+   */
+  def stop(): Unit = {
+    pythonWorker.foreach { worker =>
+      SparkEnv.get.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+    }
+  }
 }
diff --git a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
index 054788539f29..48a9848de400 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py
@@ -76,7 +76,9 @@ def process(df_id, batch_id):  # type: ignore[no-untyped-def]
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each micro batch.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index 8eb310461b6f..7aef911426de 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -89,7 +89,9 @@ def process(listener_event_str, listener_event_type):  # type: ignore[no-untyped-def]
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    # There could be a long time between each listener event.
+    sock.settimeout(None)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
index 547462d4da6d..4bf58bf7807b 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py
@@ -60,6 +60,10 @@ def test_listener_events(self):
         try:
             self.spark.streams.addListener(test_listener)
 
+            # This ensures the read socket on the server won't crash (i.e. because of timeout)
+            # when there hasn't been a new event for a long time
+            time.sleep(30)
+
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
             q = df.writeStream.format("noop").queryName("test").start()
 
@@ -76,6 +80,9 @@ def test_listener_events(self):
         finally:
             self.spark.streams.removeListener(test_listener)
 
+            # Remove again to verify this won't throw any error
+            self.spark.streams.removeListener(test_listener)
+
 
 if __name__ == "__main__":
     import unittest
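
For context, the reworked StreamingPythonRunner lifecycle introduced above (construct with the session id and worker module, call init() with no arguments, call stop() when the listener is removed) can be exercised roughly as follows. This is a minimal sketch, not part of the patch: ListenerRunnerSketch, runListenerWorker and eventJson are illustrative names invented for the example; StreamingPythonRunner, runner.init(), runner.stop() and PythonRDD.writeUTF are the APIs added or used in the diff.

import org.apache.spark.api.python.{PythonFunction, PythonRDD, StreamingPythonRunner}

// Illustrative only: mirrors how StreamingQueryListenerHelper drives the Python
// worker after this patch. The object and method names here are hypothetical.
object ListenerRunnerSketch {
  def runListenerWorker(
      pythonFn: PythonFunction,
      connectUrl: String,
      sessionId: String,
      eventJson: String): Unit = {
    // Session id and worker module are now fixed at construction time.
    val runner = StreamingPythonRunner(
      pythonFn,
      connectUrl,
      sessionId,
      "pyspark.sql.connect.streaming.worker.listener_worker")

    // init() therefore takes no arguments; it starts the Python worker process and
    // returns the output/input streams connected to it.
    val (dataOut, _) = runner.init()

    // Events are serialized to JSON and written to the worker, as in
    // PythonStreamingQueryListener.onQueryStarted and the other callbacks.
    PythonRDD.writeUTF(eventJson, dataOut)
    dataOut.flush()

    // New in this patch: stop() destroys the cached Python worker. This is what
    // stopListenerProcess() invokes when SessionHolder.removeCachedListener removes
    // a PythonStreamingQueryListener.
    runner.stop()
  }
}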