@@ -156,7 +156,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
     executeManagerCmd(
       _.getAddListenerBuilder
         .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
-          .serialize(StreamingListenerPacket(id, listener)))))
+          .serialize(StreamingListenerPacket(id, listener))))
+        .setId(id))
   }

   /**
@@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
     val id = getIdByListener(listener)
     executeManagerCmd(
       _.getRemoveListenerBuilder
-        .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
-          .serialize(StreamingListenerPacket(id, listener)))))
+        .setId(id))
     removeCachedListener(id)
   }

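For context, the user-facing listener API served by this wire change is unchanged on the PySpark side. A minimal usage sketch (assuming an active SparkSession bound to `spark`; the listener class and its prints are illustrative):

```python
from pyspark.sql.streaming import StreamingQueryListener


class MyListener(StreamingQueryListener):
    def onQueryStarted(self, event):
        print(f"query started: {event.id}")

    def onQueryProgress(self, event):
        print(f"batch completed: {event.progress.batchId}")

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print(f"query terminated: {event.id}")


listener = MyListener()
# addListener still ships the serialized listener, now tagged with an id.
spark.streams.addListener(listener)
# After this change, removeListener sends only the cached id over the wire.
spark.streams.removeListener(listener)
```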

@@ -294,7 +294,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging {
       spark.sql("DROP TABLE IF EXISTS my_listener_table")
     }

-    // List listeners after adding a new listener, length should be 2.
+    // List listeners after adding a new listener, length should be 1.
     val listeners = spark.streams.listListeners()
     assert(listeners.length == 1)

@@ -365,6 +365,8 @@ message StreamingQueryManagerCommand {

   message StreamingQueryListenerCommand {
     bytes listener_payload = 1;
+    optional PythonUDF python_listener_payload = 2;
+    string id = 3;
   }
 }

@@ -2823,7 +2823,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
           StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)

         case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
-          throw InvalidPlanInput("Unexpected") // Unreachable
+          throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
       }

       writer.foreachBatch(foreachBatchFn)
@@ -3066,23 +3066,27 @@
         respBuilder.setResetTerminated(true)

       case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER =>
-        val listenerPacket = Utils
-          .deserialize[StreamingListenerPacket](
-            command.getAddListener.getListenerPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-        val listener: StreamingQueryListener = listenerPacket.listener
-          .asInstanceOf[StreamingQueryListener]
-        val id: String = listenerPacket.id
+        val listener = if (command.getAddListener.hasPythonListenerPayload) {
+          new PythonStreamingQueryListener(
+            transformPythonFunction(command.getAddListener.getPythonListenerPayload),
+            sessionHolder,
+            pythonExec)
+        } else {
+          val listenerPacket = Utils
+            .deserialize[StreamingListenerPacket](
+              command.getAddListener.getListenerPayload.toByteArray,
+              Utils.getContextOrSparkClassLoader)
+
+          listenerPacket.listener.asInstanceOf[StreamingQueryListener]
+        }
+
+        val id = command.getAddListener.getId
         sessionHolder.cacheListenerById(id, listener)
         session.streams.addListener(listener)
         respBuilder.setAddListener(true)

       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
-        val listenerId = Utils
-          .deserialize[StreamingListenerPacket](
-            command.getRemoveListener.getListenerPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-          .id
+        val listenerId = command.getRemoveListener.getId
         val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId)
         session.streams.removeListener(listener)
         sessionHolder.removeCachedListener(listenerId)
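
To illustrate the two listener paths the planner now distinguishes, here is a hedged sketch of how a client could populate the updated command. The outer proto field names `add_listener` and `remove_listener` are assumptions inferred from the `getAddListener`/`getRemoveListener` getters above, the stub import path is assumed, and `python_udf` stands in for an already-built `PythonUDF` message:

```python
import uuid

# Assumption: the generated proto stubs are exposed under pyspark.sql.connect.proto.
import pyspark.sql.connect.proto as pb2

listener_id = str(uuid.uuid4())

add_cmd = pb2.StreamingQueryManagerCommand()
add_cmd.add_listener.id = listener_id
# A Python client would send the pickled handler as a PythonUDF payload instead
# of the Java-serialized listener_payload used by the Scala client:
# add_cmd.add_listener.python_listener_payload.CopyFrom(python_udf)

remove_cmd = pb2.StreamingQueryManagerCommand()
remove_cmd.remove_listener.id = listener_id  # the id alone suffices for removal
```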

@@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner

 import java.util.UUID

-import org.apache.spark.api.python.PythonRDD
-import org.apache.spark.api.python.SimplePythonFunction
-import org.apache.spark.api.python.StreamingPythonRunner
+import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -90,7 +88,10 @@
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
     val runner = StreamingPythonRunner(pythonFn, connectUrl)
-    val (dataOut, dataIn) = runner.init(sessionHolder.sessionId)
+    val (dataOut, dataIn) =
+      runner.init(
+        sessionHolder.sessionId,
+        "pyspark.sql.connect.streaming.worker.foreachBatch_worker")

     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {

@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.planner

import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
 * A helper class for handling StreamingQueryListener related functionality in Spark Connect.
 * Each instance of this class starts a Python process that runs the Python listener logic.
 * When a new event is received, it is serialized to JSON and passed to the Python process.
 */
class PythonStreamingQueryListener(
    listener: SimplePythonFunction,
    sessionHolder: SessionHolder,
    pythonExec: String)
    extends StreamingQueryListener {

  val port = SparkConnectService.localPort
  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
  val runner = StreamingPythonRunner(listener, connectUrl)

  val (dataOut, _) =
    runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker")

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(0)
    dataOut.flush()
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(1)
    dataOut.flush()
  }

  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(2)
    dataOut.flush()
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(3)
    dataOut.flush()
  }

  // TODO(SPARK-44433)(SPARK-44516): Improve termination of processes.
  // Similar to foreachBatch, where we need to exit the process when the query ends,
  // in listener semantics we need to exit the process when removeListener is called.
}
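
The listener above frames each event for the Python process as a length-prefixed UTF-8 JSON string (via `PythonRDD.writeUTF`) followed by a 4-byte event code. Below is a minimal sketch of the consuming loop on the Python side, assuming that framing; the real worker module is `pyspark.sql.connect.streaming.worker.listener_worker`, and it would rebuild typed event objects rather than passing raw dicts as done here:

```python
import json
import struct


def read_int(infile):
    # DataOutputStream.writeInt sends a big-endian 4-byte int.
    data = infile.read(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack(">i", data)[0]


def read_utf(infile):
    # Mirrors PythonRDD.writeUTF: 4-byte length prefix, then UTF-8 bytes.
    length = read_int(infile)
    return infile.read(length).decode("utf-8")


def listener_loop(infile, listener):
    # Event codes match the writeInt calls in PythonStreamingQueryListener:
    # 0 = started, 1 = progress, 2 = idle, 3 = terminated.
    handlers = {
        0: listener.onQueryStarted,
        1: listener.onQueryProgress,
        2: listener.onQueryIdle,
        3: listener.onQueryTerminated,
    }
    while True:
        event_json = read_utf(infile)
        code = read_int(infile)
        # Sketch only: the real worker deserializes the JSON into event objects.
        handlers[code](json.loads(event_json))
```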

@@ -48,9 +48,8 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
    * Initializes the Python worker for streaming functions. Sets up Spark Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String): (DataOutputStream, DataInputStream) = {
+  def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
     logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")
-
     val env = SparkEnv.get

     val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -62,7 +61,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)

     val pythonWorkerFactory =
-      new PythonWorkerFactory(pythonExec, "pyspark.streaming_worker", envVars.asScala.toMap)
+      new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
     val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker()

     val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)

dev/sparktestsupport/modules.py (1 change: 1 addition & 0 deletions)
@@ -863,6 +863,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.client.test_artifact",
         "pyspark.sql.tests.connect.client.test_client",
         "pyspark.sql.tests.connect.streaming.test_parity_streaming",
+        "pyspark.sql.tests.connect.streaming.test_parity_listener",
         "pyspark.sql.tests.connect.streaming.test_parity_foreach",
         "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
         "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",