Commits (50)
- d363260 foreachbatch spark connect (pengzhon-db, Apr 21, 2023)
- 508928a add streaming_worker.py (pengzhon-db, May 5, 2023)
- 6cb7d01 python proto (pengzhon-db, May 5, 2023)
- 74ad159 use same python process for one streaming query (pengzhon-db, May 6, 2023)
- 1f73c6f wip (WweiL, Jun 21, 2023)
- 2dec59c working (WweiL, Jun 22, 2023)
- 40bf120 latest change (WweiL, Jul 11, 2023)
- 57b80ff resolve conflicts (WweiL, Jul 23, 2023)
- c5c8e85 streaming function declaration pyi & rdd (WweiL, Jul 23, 2023)
- 8b9dfe5 this won't work, still throws None obj doesn't have craeteDataFrame (WweiL, Jul 24, 2023)
- 341d588 same error, this also doesn't work (WweiL, Jul 24, 2023)
- 4d8eec3 this worked (WweiL, Jul 24, 2023)
- 3a43d6c first revision (WweiL, Jul 24, 2023)
- 38d76c0 file cleanup (WweiL, Jul 24, 2023)
- 1923840 doc update (WweiL, Jul 24, 2023)
- 8bcf605 is this breaking change also? (WweiL, Jul 24, 2023)
- 56665c3 remove doc test for now (WweiL, Jul 24, 2023)
- cb0caea add remove listener (WweiL, Jul 24, 2023)
- 9c4f6e6 gen proto (WweiL, Jul 24, 2023)
- 3724992 ticket update (WweiL, Jul 24, 2023)
- 2e3b8d2 before resolving merge conflict (WweiL, Jul 24, 2023)
- 4a6a15f resolve conflict (WweiL, Jul 24, 2023)
- 0040b70 documentation to PythonStreamingQueryListener (WweiL, Jul 24, 2023)
- d404f9f this works on manual test but in unit test it shows No module named '… (WweiL, Jul 25, 2023)
- a2041ec works now (WweiL, Jul 25, 2023)
- 4a2d184 minor (WweiL, Jul 25, 2023)
- 5d04245 minor (WweiL, Jul 25, 2023)
- 8c97ecc doesn't need to make q stateful (WweiL, Jul 25, 2023)
- 494c243 fmt (WweiL, Jul 25, 2023)
- 28d1495 why is there a unused import (WweiL, Jul 25, 2023)
- b7cc36f minor (WweiL, Jul 25, 2023)
- 20b87e6 try to resolve breaking change, will address other comments tmr (WweiL, Jul 26, 2023)
- 72552ed scala client send ids (WweiL, Jul 26, 2023)
- 0640fdf remove return NOne (WweiL, Jul 26, 2023)
- b99301b merge master, remove println log (WweiL, Jul 26, 2023)
- 004e181 add streamingPythonEval (WweiL, Jul 26, 2023)
- 640ab23 minor: (WweiL, Jul 27, 2023)
- 270363b remove eval type, create two worker files (WweiL, Jul 27, 2023)
- 487bfa1 lint (WweiL, Jul 27, 2023)
- ed5f101 retrigger tests (WweiL, Jul 27, 2023)
- 8c5569f Merge remote-tracking branch 'spark/master' into listener-poc-newest (WweiL, Jul 27, 2023)
- 13f50bf minor (WweiL, Jul 27, 2023)
- 3f83d07 lint (WweiL, Jul 27, 2023)
- eb4d2b5 lint (WweiL, Jul 28, 2023)
- 95b0111 Merge remote-tracking branch 'spark/master' into listener-poc-newest (WweiL, Jul 28, 2023)
- fb4415b lint (WweiL, Jul 28, 2023)
- baf791c address comments, move worker files to sql/connect/streaming/worker (WweiL, Jul 29, 2023)
- 8c520ba minor, remove redundant log (WweiL, Jul 29, 2023)
- aa22c3b add init (WweiL, Jul 30, 2023)
- c9ffa52 add new pkg to setup.py (WweiL, Jul 31, 2023)
Files changed
@@ -364,7 +364,8 @@ message StreamingQueryManagerCommand {
   }

   message StreamingQueryListenerCommand {
-    bytes listener_payload = 1;
+    optional bytes listener_payload = 1;
+    optional PythonUDF python_listener_payload = 2;
   }
 }
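Making both fields optional lets the server tell a JVM-serialized Scala listener (`listener_payload`) apart from a Python one (`python_listener_payload`, carried as a `PythonUDF`). A rough sketch of how the Python client side might populate the new field, assuming the generated classes in `pyspark.sql.connect.proto`; the function name and the pre-pickled `pickled_listener` bytes are placeholders for illustration, not the PR's actual client code:

```python
import pyspark.sql.connect.proto as pb2

def build_add_listener_command(pickled_listener: bytes) -> pb2.StreamingQueryManagerCommand:
    cmd = pb2.StreamingQueryManagerCommand()
    # Python path: ship the listener as a PythonUDF payload instead of
    # JVM-serialized bytes in listener_payload.
    udf = cmd.add_listener.python_listener_payload
    udf.command = pickled_listener
    udf.python_ver = "3.10"  # client's Python version (illustrative)
    udf.eval_type = 401      # PythonEvalType.SQL_STREAMING_LISTENER
    return cmd
```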
@@ -2804,7 +2804,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
         StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder)

       case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET =>
-        throw InvalidPlanInput("Unexpected") // Unreachable
+        throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable
     }

     writer.foreachBatch(foreachBatchFn)
@@ -3047,18 +3047,29 @@
         respBuilder.setResetTerminated(true)

       case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER =>
-        val listenerPacket = Utils
-          .deserialize[StreamingListenerPacket](
-            command.getAddListener.getListenerPayload.toByteArray,
-            Utils.getContextOrSparkClassLoader)
-        val listener: StreamingQueryListener = listenerPacket.listener
-          .asInstanceOf[StreamingQueryListener]
-        val id: String = listenerPacket.id
-        sessionHolder.cacheListenerById(id, listener)
+        val listener = if (command.getAddListener.hasListenerPayload) {
+          val listenerPacket = Utils
+            .deserialize[StreamingListenerPacket](
+              command.getAddListener.getListenerPayload.toByteArray,
+              Utils.getContextOrSparkClassLoader)
+          val listener: StreamingQueryListener = listenerPacket.listener
+            .asInstanceOf[StreamingQueryListener]
+          val id: String = listenerPacket.id
+          sessionHolder.cacheListenerById(id, listener)
+          listener
+        } else {
+          val listener = new PythonStreamingQueryListener(
+            transformPythonFunction(command.getAddListener.getPythonListenerPayload),
+            sessionHolder,
+            pythonExec)
+          listener
+        }

         session.streams.addListener(listener)
         respBuilder.setAddListener(true)

       case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER =>
+        // TODO (SPARK-44516): remove listener for python client
         val listenerId = Utils
           .deserialize[StreamingListenerPacket](
             command.getRemoveListener.getListenerPayload.toByteArray,
@@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner

 import java.util.UUID

-import org.apache.spark.api.python.PythonRDD
-import org.apache.spark.api.python.SimplePythonFunction
-import org.apache.spark.api.python.StreamingPythonRunner
+import org.apache.spark.api.python.{PythonEvalType, PythonRDD, SimplePythonFunction, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -89,7 +87,8 @@ object StreamingForeachBatchHelper extends Logging {
     val port = SparkConnectService.localPort
     val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
     val runner = StreamingPythonRunner(pythonFn, connectUrl)
-    val (dataOut, dataIn) = runner.init(sessionHolder.sessionId)
+    val (dataOut, dataIn) = runner.init(
+      sessionHolder.sessionId, PythonEvalType.SQL_STREAMING_FOREACH_BATCH)

     val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => {
@@ -0,0 +1,60 @@ (new file)
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.planner

import org.apache.spark.api.python.{PythonEvalType, PythonRDD, SimplePythonFunction, StreamingPythonRunner}
import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
import org.apache.spark.sql.streaming.StreamingQueryListener

class PythonStreamingQueryListener(
    listener: SimplePythonFunction,
    sessionHolder: SessionHolder,
    pythonExec: String)
  extends StreamingQueryListener {

  val port = SparkConnectService.localPort
  val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
  val runner = StreamingPythonRunner(listener, connectUrl)

  val (dataOut, _) = runner.init(
    sessionHolder.sessionId, PythonEvalType.SQL_STREAMING_LISTENER)

  override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(0)
    dataOut.flush()
  }

  override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(1)
    dataOut.flush()
  }

  override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(2)
    dataOut.flush()
  }

  override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {
    PythonRDD.writeUTF(event.json, dataOut)
    dataOut.writeInt(3)
    dataOut.flush()
  }
}
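Each callback ships the event to the long-running Python worker as length-prefixed UTF-8 JSON followed by a 4-byte tag (0 = started, 1 = progress, 2 = idle, 3 = terminated). A minimal sketch of the receiving loop on the Python side, assuming pyspark's serializer helpers; the `listener` argument and dict-based dispatch are illustrative assumptions, not the PR's actual worker file (the real worker would rebuild typed event objects from the JSON):

```python
import json

from pyspark.serializers import read_int, UTF8Deserializer

utf8_deserializer = UTF8Deserializer()

def listener_event_loop(infile, listener):
    # Tags mirror PythonStreamingQueryListener: each event arrives as
    # length-prefixed UTF-8 JSON followed by a 4-byte big-endian int tag.
    handlers = {
        0: listener.onQueryStarted,
        1: listener.onQueryProgress,
        2: listener.onQueryIdle,
        3: listener.onQueryTerminated,
    }
    while True:
        event_json = utf8_deserializer.loads(infile)  # reads int length + bytes
        tag = read_int(infile)
        handlers[tag](json.loads(event_json))  # real worker rebuilds typed events
```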
@@ -59,6 +59,9 @@ private[spark] object PythonEvalType {
   val SQL_TABLE_UDF = 300
   val SQL_ARROW_TABLE_UDF = 301

+  val SQL_STREAMING_FOREACH_BATCH = 400
+  val SQL_STREAMING_LISTENER = 401
+
Review comment (Member): Can we have a separate worker file instead of using an eval type? Event handling is async, so it might conflict with other existing running Python workers; at the least, I suspect they would affect each other's execution time.
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
     case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
@@ -74,6 +77,8 @@
     case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
     case SQL_TABLE_UDF => "SQL_TABLE_UDF"
     case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
+    case SQL_STREAMING_FOREACH_BATCH => "SQL_STREAMING_FOREACH_BATCH"
+    case SQL_STREAMING_LISTENER => "SQL_STREAMING_LISTENER"
   }
 }
@@ -48,7 +48,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str
    * Initializes the Python worker for streaming functions. Sets up Spark Connect session
    * to be used with the functions.
    */
-  def init(sessionId: String): (DataOutputStream, DataInputStream) = {
+  def init(sessionId: String, evalType: Int): (DataOutputStream, DataInputStream) = {
     log.info(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec")

     val env = SparkEnv.get
@@ -73,6 +73,9 @@
     // Send sessionId
     PythonRDD.writeUTF(sessionId, dataOut)

+    // Send evalType
+    dataOut.writeInt(evalType)
+
     // send the user function to python process
     val command = func.command
     dataOut.writeInt(command.length)
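Since `init` now writes the eval type between the session id and the pickled function, the worker-side handshake has to read the same values in the same order. A rough sketch of that read path, assuming pyspark's serializer helpers; `read_streaming_handshake` is a hypothetical name for illustration, not the PR's actual worker code:

```python
from pyspark.rdd import PythonEvalType
from pyspark.serializers import read_int, UTF8Deserializer

utf8_deserializer = UTF8Deserializer()

def read_streaming_handshake(infile):
    # Order mirrors StreamingPythonRunner.init: session id (length-prefixed
    # UTF-8), eval type (4-byte int), then the pickled user function bytes.
    session_id = utf8_deserializer.loads(infile)
    eval_type = read_int(infile)
    command = infile.read(read_int(infile))
    assert eval_type in (
        PythonEvalType.SQL_STREAMING_FOREACH_BATCH,
        PythonEvalType.SQL_STREAMING_LISTENER,
    )
    return session_id, eval_type, command
```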
python/pyspark/rdd.py (5 changes: 5 additions & 0 deletions)
@@ -117,6 +117,8 @@
     SQLArrowTableUDFType,
     SQLBatchedUDFType,
     SQLTableUDFType,
+    SQLStreamingForeachBatchType,
+    SQLStreamingListenerType,
 )

 from py4j.java_gateway import JavaObject
@@ -162,6 +164,9 @@ class PythonEvalType:
     SQL_TABLE_UDF: "SQLTableUDFType" = 300
     SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301

+    SQL_STREAMING_FOREACH_BATCH: "SQLStreamingForeachBatchType" = 400
+    SQL_STREAMING_LISTENER: "SQLStreamingListenerType" = 401
+

 def portable_hash(x: Hashable) -> int:
     """
python/pyspark/sql/_typing.pyi (3 changes: 3 additions & 0 deletions)
@@ -61,6 +61,9 @@ SQLArrowBatchedUDFType = Literal[101]
 SQLTableUDFType = Literal[300]
 SQLArrowTableUDFType = Literal[301]

+SQLStreamingForeachBatchType = Literal[400]
+SQLStreamingListenerType = Literal[401]
+
 class SupportsOpen(Protocol):
     def open(self, partition_id: int, epoch_id: int) -> bool: ...