@@ -344,6 +344,8 @@ message StreamingQueryManagerCommand {
AwaitAnyTerminationCommand await_any_termination = 3;
// resetTerminated() API.
bool reset_terminated = 4;
// addListener() API.
AddStreamingQueryListenerCommand add_listener = 5;
}

message AwaitAnyTerminationCommand {
@@ -377,6 +379,14 @@ message StreamingQueryManagerCommandResult {
}
}

// TODO: maybe serialize the whole class?
message AddStreamingQueryListenerCommand {
bytes on_query_started = 1;
bytes on_query_progress = 2;
bytes on_query_terminated = 3;
// TODO: deserialize in python or in scala?
}

// Command to get the output of 'SparkContext.resources'
message GetResourcesCommand { }

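The new AddStreamingQueryListenerCommand carries each listener callback as separately cloudpickled bytes. For context, a minimal sketch of how a client could populate it; the helper name is hypothetical, and wiring the command into the outer request is not shown in this diff:

from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect.proto import commands_pb2 as pb2

def build_add_listener_command(listener):
    # Pickle each callback on its own, matching the three bytes fields above.
    ser = CloudPickleSerializer()
    add_listener = pb2.AddStreamingQueryListenerCommand(
        on_query_started=ser.dumps(listener.onQueryStarted),
        on_query_progress=ser.dumps(listener.onQueryProgress),
        on_query_terminated=ser.dumps(listener.onQueryTerminated),
    )
    return pb2.StreamingQueryManagerCommand(add_listener=add_listener)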
@@ -2646,6 +2646,11 @@ class SparkConnectPlanner(val session: SparkSession) {
session.streams.resetTerminated()
respBuilder.setResetTerminated(true)

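// Registers a server-side listener; each callback forks a Python worker process
// that unpickles and runs the corresponding user function
// (see PythonStreamingQueryListener below).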
case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER =>
val listener =
new PythonStreamingQueryListener(command.getAddListener, sessionId, pythonExec)
session.streams.addListener(listener)

case StreamingQueryManagerCommand.CommandCase.COMMAND_NOT_SET =>
throw new IllegalArgumentException("Missing command in StreamingQueryManagerCommand")
}
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect.planner

import java.util.Base64

import scala.io.Source

import org.apache.spark.api.python.PythonUtils
import org.apache.spark.connect.proto
import org.apache.spark.sql.streaming.StreamingQueryListener

class PythonStreamingQueryListener(
listener: proto.AddStreamingQueryListenerCommand,
sessionId: String,
pythonExec: String)
extends StreamingQueryListener {
// Start a process to run the Python listener function
// TODO: Reuse some functions from PythonRunner.scala
// TODO: Handle the process better: reuse, release, and monitor it
// val envVars = udf.func.envVars.asScala.toMap

val pb = new ProcessBuilder()
val pbEnv = pb.environment()
val pythonPath = PythonUtils.mergePythonPaths(
PythonUtils.sparkPythonPath,
// envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))
pbEnv.put("PYTHONPATH", pythonPath)
// pbEnv.putAll(envVars.asJava)

pb.command(pythonExec)

// Encode the serialized function as a Base64 string so that it can be passed
// to the Python process as a command-line argument
val onQueryStartedBytes = listener.getOnQueryStarted.toByteArray
val onQueryStartedStr = Base64.getEncoder().encodeToString(onQueryStartedBytes)


// TODO(Wei): serialize and deserialize events

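// Hand-rolled JSON for QueryStartedEvent. Note the values are interpolated
// without escaping, so a query name containing quotes would break parsing.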
private def toJSON(event: StreamingQueryListener.QueryStartedEvent): String =
s"""
|{
| "id": "${event.id}",
| "runId": "${event.runId}",
| "name": "${event.name}",
| "timestamp": "${event.timestamp}"
|}
""".stripMargin

override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = {
val eventJson = toJSON(event)
val pythonScript = s"""
|print('###### Start running onQueryStarted ######')
|from pyspark.sql import SparkSession
|from pyspark.serializers import CloudPickleSerializer
|from pyspark.sql.connect.streaming.listener import (
| StreamingQueryListener,
| QueryStartedEvent
|)
|from pyspark.sql.streaming.listener import (
| QueryProgressEvent,
| QueryTerminatedEvent,
| QueryIdleEvent
|)
|import sys
|import base64
|import json
|
|startEvent = QueryStartedEvent.fromJson(json.loads('''$eventJson'''))
|sessionId = '$sessionId'
|sparkConnectSession = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
|sparkConnectSession._client._session_id = sessionId
|
|func_bytes = base64.b64decode('$onQueryStartedStr')
|func = CloudPickleSerializer().loads(func_bytes)
|func(startEvent)
|exit()
""".stripMargin
pb.command(pythonExec, "-c", pythonScript)
val process = pb.start()
// Output for debug for now.
// TODO: redirect the output stream
// TODO: handle error
// TODO(WEI): python ver?
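// Note: stdout is drained fully before stderr; if the child fills the stderr
// pipe buffer first, this can block. redirectErrorStream(true) on the
// ProcessBuilder would avoid that.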
val is = process.getInputStream()
// scalastyle:off println
val out = Source.fromInputStream(is).mkString
println(s"##### Python out for query start event is: out=$out")

val es = process.getErrorStream
val errorOut = Source.fromInputStream(es).mkString
println(s"##### Python error for query start event is: error=$errorOut")

val exitCode = process.waitFor()
println(s"##### End processing query start event exitCode=$exitCode")
// scalastyle:on println
}

override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = {}

override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = {}
}
52 changes: 34 additions & 18 deletions python/pyspark/sql/connect/proto/commands_pb2.py

Large diffs are not rendered by default.

44 changes: 43 additions & 1 deletion python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -1323,6 +1323,7 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message):
GET_QUERY_FIELD_NUMBER: builtins.int
AWAIT_ANY_TERMINATION_FIELD_NUMBER: builtins.int
RESET_TERMINATED_FIELD_NUMBER: builtins.int
ADD_LISTENER_FIELD_NUMBER: builtins.int
active: builtins.bool
"""active() API, returns a list of active queries."""
get_query: builtins.str
@@ -1334,6 +1335,9 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message):
"""awaitAnyTermination() API, wait until any query terminates or timeout."""
reset_terminated: builtins.bool
"""resetTerminated() API."""
@property
def add_listener(self) -> global___AddStreamingQueryListenerCommand:
"""addListener() API."""
def __init__(
self,
*,
@@ -1342,12 +1346,15 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message):
await_any_termination: global___StreamingQueryManagerCommand.AwaitAnyTerminationCommand
| None = ...,
reset_terminated: builtins.bool = ...,
add_listener: global___AddStreamingQueryListenerCommand | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"active",
b"active",
"add_listener",
b"add_listener",
"await_any_termination",
b"await_any_termination",
"command",
@@ -1363,6 +1370,8 @@
field_name: typing_extensions.Literal[
"active",
b"active",
"add_listener",
b"add_listener",
"await_any_termination",
b"await_any_termination",
"command",
@@ -1376,7 +1385,7 @@
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command", b"command"]
) -> typing_extensions.Literal[
"active", "get_query", "await_any_termination", "reset_terminated"
"active", "get_query", "await_any_termination", "reset_terminated", "add_listener"
] | None: ...

global___StreamingQueryManagerCommand = StreamingQueryManagerCommand
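Because add_listener extends the existing command oneof, assigning any other member clears it. A quick illustration against the generated bindings (field and message names as in this stub):

from pyspark.sql.connect.proto import commands_pb2 as pb2

cmd = pb2.StreamingQueryManagerCommand(
    add_listener=pb2.AddStreamingQueryListenerCommand(on_query_started=b"\x80")
)
assert cmd.WhichOneof("command") == "add_listener"
assert cmd.HasField("add_listener")

# Setting another oneof member replaces add_listener.
cmd.reset_terminated = True
assert cmd.WhichOneof("command") == "reset_terminated"
assert not cmd.HasField("add_listener")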
@@ -1510,6 +1519,39 @@ class StreamingQueryManagerCommandResult(google.protobuf.message.Message):

global___StreamingQueryManagerCommandResult = StreamingQueryManagerCommandResult

class AddStreamingQueryListenerCommand(google.protobuf.message.Message):
"""TODO: maybe serialize the whole class?"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor

ON_QUERY_STARTED_FIELD_NUMBER: builtins.int
ON_QUERY_PROGRESS_FIELD_NUMBER: builtins.int
ON_QUERY_TERMINATED_FIELD_NUMBER: builtins.int
on_query_started: builtins.bytes
on_query_progress: builtins.bytes
on_query_terminated: builtins.bytes
"""TODO: deserialize in python or in scala?"""
def __init__(
self,
*,
on_query_started: builtins.bytes = ...,
on_query_progress: builtins.bytes = ...,
on_query_terminated: builtins.bytes = ...,
) -> None: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
"on_query_progress",
b"on_query_progress",
"on_query_started",
b"on_query_started",
"on_query_terminated",
b"on_query_terminated",
],
) -> None: ...

global___AddStreamingQueryListenerCommand = AddStreamingQueryListenerCommand

class GetResourcesCommand(google.protobuf.message.Message):
"""Command to get the output of 'SparkContext.resources'"""
