Closed

Changes from all commits (50 commits)
d363260
foreachbatch spark connect
pengzhon-db Apr 21, 2023
508928a
add streaming_worker.py
pengzhon-db May 5, 2023
6cb7d01
python proto
pengzhon-db May 5, 2023
74ad159
use same python process for one streaming query
pengzhon-db May 6, 2023
1f73c6f
wip
WweiL Jun 21, 2023
2dec59c
working
WweiL Jun 22, 2023
40bf120
latest change
WweiL Jul 11, 2023
57b80ff
resolve conflicts
WweiL Jul 23, 2023
c5c8e85
streaming function declaration pyi & rdd
WweiL Jul 23, 2023
8b9dfe5
this won't work, still throws None obj doesn't have createDataFrame
WweiL Jul 24, 2023
341d588
same error, this also doesn't work
WweiL Jul 24, 2023
4d8eec3
this worked
WweiL Jul 24, 2023
3a43d6c
first revision
WweiL Jul 24, 2023
38d76c0
file cleanup
WweiL Jul 24, 2023
1923840
doc update
WweiL Jul 24, 2023
8bcf605
is this breaking change also?
WweiL Jul 24, 2023
56665c3
remove doc test for now
WweiL Jul 24, 2023
cb0caea
add remove listener
WweiL Jul 24, 2023
9c4f6e6
gen proto
WweiL Jul 24, 2023
3724992
ticket update
WweiL Jul 24, 2023
2e3b8d2
before resolving merge conflict
WweiL Jul 24, 2023
4a6a15f
resolve conflict
WweiL Jul 24, 2023
0040b70
documentation to PythonStreamingQueryListener
WweiL Jul 24, 2023
d404f9f
this works on manual test but in unit test it shows No module named '…
WweiL Jul 25, 2023
a2041ec
works now
WweiL Jul 25, 2023
4a2d184
minor
WweiL Jul 25, 2023
5d04245
minor
WweiL Jul 25, 2023
8c97ecc
doesn't need to make q stateful
WweiL Jul 25, 2023
494c243
fmt
WweiL Jul 25, 2023
28d1495
why is there a unused import
WweiL Jul 25, 2023
b7cc36f
minor
WweiL Jul 25, 2023
20b87e6
try to resolve breaking change, will address other comments tmr
WweiL Jul 26, 2023
72552ed
scala client send ids
WweiL Jul 26, 2023
0640fdf
remove return None
WweiL Jul 26, 2023
b99301b
merge master, remove println log
WweiL Jul 26, 2023
004e181
add streamingPythonEval
WweiL Jul 26, 2023
640ab23
minor:
WweiL Jul 27, 2023
270363b
remove eval type, create two worker files
WweiL Jul 27, 2023
487bfa1
lint
WweiL Jul 27, 2023
ed5f101
retrigger tests
WweiL Jul 27, 2023
8c5569f
Merge remote-tracking branch 'spark/master' into listener-poc-newest
WweiL Jul 27, 2023
13f50bf
minor
WweiL Jul 27, 2023
3f83d07
lint
WweiL Jul 27, 2023
eb4d2b5
lint
WweiL Jul 28, 2023
95b0111
Merge remote-tracking branch 'spark/master' into listener-poc-newest
WweiL Jul 28, 2023
fb4415b
lint
WweiL Jul 28, 2023
baf791c
address comments, move worker files to sql/connect/streaming/worker
WweiL Jul 29, 2023
8c520ba
minor, remove redundant log
WweiL Jul 29, 2023
aa22c3b
add init
WweiL Jul 30, 2023
c9ffa52
add new pkg to setup.py
WweiL Jul 31, 2023
@@ -156,7 +156,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
executeManagerCmd(
_.getAddListenerBuilder
.setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
.serialize(StreamingListenerPacket(id, listener)))))
.serialize(StreamingListenerPacket(id, listener))))
.setId(id))
}

/**
@@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
val id = getIdByListener(listener)
executeManagerCmd(
_.getRemoveListenerBuilder
.setListenerPayload(ByteString.copyFrom(SparkSerDeUtils
.serialize(StreamingListenerPacket(id, listener)))))
.setId(id))
removeCachedListener(id)
}

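The add/remove commands are now keyed by a client-generated id: addListener still ships the serialized packet but also sets the id, while removeListener sends only the id, so the listener never has to be re-serialized. A minimal Python sketch of the id-keyed bookkeeping this implies on the client (names are illustrative, mirroring getIdByListener / removeCachedListener above):

```python
import uuid

class ListenerCache:
    """Illustrative stand-in for the Scala client's id -> listener cache
    (cf. getIdByListener / removeCachedListener in the hunk above)."""

    def __init__(self):
        self._listeners = {}  # id -> listener

    def cache_listener(self, listener):
        listener_id = str(uuid.uuid4())  # generated once, reused for removal
        self._listeners[listener_id] = listener
        return listener_id

    def get_id_by_listener(self, listener):
        for listener_id, cached in self._listeners.items():
            if cached is listener:
                return listener_id
        raise KeyError("listener was never added")

    def remove_cached_listener(self, listener_id):
        self._listeners.pop(listener_id, None)
```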
@@ -294,7 +294,7 @@ class StreamingQuerySuite extends QueryTest with SQLHelper with Logging {
spark.sql("DROP TABLE IF EXISTS my_listener_table")
}

// List listeners after adding a new listener, length should be 2.
// List listeners after adding a new listener, length should be 1.
val listeners = spark.streams.listListeners()
assert(listeners.length == 1)

@@ -365,6 +365,8 @@ message StreamingQueryManagerCommand {

message StreamingQueryListenerCommand {
bytes listener_payload = 1;
optional PythonUDF python_listener_payload = 2;
string id = 3;
}
}

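For illustration, a sketch of how a client could populate the new fields, assuming the generated classes are importable as pyspark.sql.connect.proto (the placeholder payload bytes are an assumption):

```python
import uuid
import pyspark.sql.connect.proto as pb2

listener_id = str(uuid.uuid4())

add = pb2.StreamingQueryManagerCommand()
add.add_listener.id = listener_id
# A Scala client sets listener_payload (Java serialization); a Python client
# would set python_listener_payload (a PythonUDF message) instead.
add.add_listener.listener_payload = b"\x00"  # placeholder bytes

remove = pb2.StreamingQueryManagerCommand()
remove.remove_listener.id = listener_id  # removal now needs only the id
```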
@@ -2823,7 +2823,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)

case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
throw InvalidPlanInput("Unexpected") // Unreachable
throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
}

writer.foreachBatch(foreachBatchFn)
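For context, the user-facing API this server path implements is the ordinary foreachBatch sink, now usable over Connect (connection URL and sink table are illustrative):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

def process(batch_df, batch_id):
    # Runs inside the Python worker that the server starts for this query.
    batch_df.write.mode("append").saveAsTable("sink_table")

query = (
    spark.readStream.format("rate").load()
    .writeStream.foreachBatch(process)
    .start()
)
```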
@@ -3066,23 +3066,27 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
respBuilder.setResetTerminated(true)

case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER =>
val listenerPacket = Utils
.deserialize[StreamingListenerPacket](
command.getAddListener.getListenerPayload.toByteArray,
Utils.getContextOrSparkClassLoader)
val listener: StreamingQueryListener = listenerPacket.listener
.asInstanceOf[StreamingQueryListener]
val id: String = listenerPacket.id
val listener = if (command.getAddListener.hasPythonListenerPayload) {
new PythonStreamingQueryListener(
transformPythonFunction(command.getAddListener.getPythonListenerPayload),
sessionHolder,
pythonExec)
} else {
val listenerPacket = Utils
.deserialize[StreamingListenerPacket](
command.getAddListener.getListenerPayload.toByteArray,
Utils.getContextOrSparkClassLoader)

listenerPacket.listener.asInstanceOf[StreamingQueryListener]
}

val id = command.getAddListener.getId
sessionHolder.cacheListenerById(id, listener)
session.streams.addListener(listener)
respBuilder.setAddListener(true)

case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
val listenerId = Utils
.deserialize[StreamingListenerPacket](
command.getRemoveListener.getListenerPayload.toByteArray,
Utils.getContextOrSparkClassLoader)
.id
val listenerId = command.getRemoveListener.getId
[Contributor review comment] Thanks for doing this change!

val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId)
session.streams.removeListener(listener)
sessionHolder.removeCachedListener(listenerId)
@@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner

import java.util.UUID

import org.apache.spark.api.python.PythonRDD
import org.apache.spark.api.python.SimplePythonFunction
import org.apache.spark.api.python.StreamingPythonRunner
import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.connect.service.SessionHolder
@@ -90,7 +88,10 @@ object StreamingForeachBatchHelper extends Logging {
val port = SparkConnectService.localPort
val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
val runner = StreamingPythonRunner(pythonFn, connectUrl)
val (dataOut, dataIn) = runner.init(sessionHolder.sessionId)
val (dataOut, dataIn) =
runner.init(
sessionHolder.sessionId,
"pyspark.sql.connect.streaming.worker.foreachBatch_worker")

val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {

@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.planner

import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner}
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener

/**
* A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each
* instance of this class starts a Python process that runs the Python-side handling logic. When a
* new event is received, it is serialized to JSON and passed to the Python process.
*/
class PythonStreamingQueryListener(
listener: SimplePythonFunction,
sessionHolder: SessionHolder,
pythonExec: String)
extends StreamingQueryListener {

val port = SparkConnectService.localPort
val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
val runner = StreamingPythonRunner(listener, connectUrl)

val (dataOut, _) =
runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker")

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
PythonRDD.writeUTF(event.json, dataOut)
dataOut.writeInt(0)
dataOut.flush()
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
PythonRDD.writeUTF(event.json, dataOut)
dataOut.writeInt(1)
dataOut.flush()
}

override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
PythonRDD.writeUTF(event.json, dataOut)
dataOut.writeInt(2)
dataOut.flush()
}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
PythonRDD.writeUTF(event.json, dataOut)
dataOut.writeInt(3)
dataOut.flush()
}

// TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes.
// Similar to foreachBatch, where the process needs to exit when the query ends.
// For listener semantics, the process needs to exit when removeListener is called.
}
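The Python counterpart of this protocol lives in pyspark.sql.connect.streaming.worker.listener_worker (not shown in this diff). A sketch of the receive loop it implies, matching the writes above, a length-prefixed UTF-8 JSON string followed by a 4-byte event code (the loop structure is an assumption; the serializer helpers are PySpark's):

```python
import json

from pyspark.serializers import read_int, UTF8Deserializer
from pyspark.sql.streaming.listener import (
    QueryIdleEvent,
    QueryProgressEvent,
    QueryStartedEvent,
    QueryTerminatedEvent,
)

utf8 = UTF8Deserializer()

def serve(listener, infile):
    while True:
        event = json.loads(utf8.loads(infile))  # matches PythonRDD.writeUTF
        code = read_int(infile)                 # matches dataOut.writeInt
        if code == 0:
            listener.onQueryStarted(QueryStartedEvent.fromJson(event))
        elif code == 1:
            listener.onQueryProgress(QueryProgressEvent.fromJson(event))
        elif code == 2:
            listener.onQueryIdle(QueryIdleEvent.fromJson(event))
        elif code == 3:
            listener.onQueryTerminated(QueryTerminatedEvent.fromJson(event))
```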
@@ -48,9 +48,8 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
* Initializes the Python worker for streaming functions. Sets up Spark Connect session
* to be used with the functions.
*/
def init(sessionId: String): (DataOutputStream, DataInputStream) = {
def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = {
logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")

val env = SparkEnv.get

val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
@@ -62,7 +61,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)

val pythonWorkerFactory =
new PythonWorkerFactory(pythonExec, "pyspark.streaming_worker", envVars.asScala.toMap)
new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap)
val (worker: Socket, _) = pythonWorkerFactory.createSimpleWorker()

val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
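With the module name now a parameter, the same runner can launch either the foreachBatch worker or the listener worker. A sketch of the worker-side start of the handshake (the read order of the session id is an assumption; SPARK_CONNECT_LOCAL_URL is set by the runner above):

```python
import os

from pyspark.serializers import UTF8Deserializer
from pyspark.sql.connect.session import SparkSession

utf8 = UTF8Deserializer()

def init_session(infile):
    session_id = utf8.loads(infile)  # assumption: runner sends the session id first
    url = os.environ["SPARK_CONNECT_LOCAL_URL"]  # exported by the runner above
    # Connect back to the server that launched this worker.
    return SparkSession.builder.remote(url).getOrCreate(), session_id
```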
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -863,6 +863,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.client.test_artifact",
"pyspark.sql.tests.connect.client.test_client",
"pyspark.sql.tests.connect.streaming.test_parity_streaming",
"pyspark.sql.tests.connect.streaming.test_parity_listener",
"pyspark.sql.tests.connect.streaming.test_parity_foreach",
"pyspark.sql.tests.connect.streaming.test_parity_foreachBatch",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state",
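In user terms, the behavior the new parity test target covers is that the existing listener API works unchanged over Connect; a sketch (the actual test contents are not part of this diff):

```python
from pyspark.sql import SparkSession
from pyspark.sql.streaming import StreamingQueryListener

class MyListener(StreamingQueryListener):
    def onQueryStarted(self, event):
        print("started:", event.id)

    def onQueryProgress(self, event):
        print("rows:", event.progress.numInputRows)

    def onQueryIdle(self, event):
        pass

    def onQueryTerminated(self, event):
        print("terminated:", event.id)

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
listener = MyListener()
spark.streams.addListener(listener)     # id generated and cached client-side
spark.streams.removeListener(listener)  # only the cached id crosses the wire
```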