diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 917444604405..d16638e59459 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -156,7 +156,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo executeManagerCmd( _.getAddListenerBuilder .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils - .serialize(StreamingListenerPacket(id, listener))))) + .serialize(StreamingListenerPacket(id, listener)))) + .setId(id)) } /** @@ -168,8 +169,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo val id = getIdByListener(listener) executeManagerCmd( _.getRemoveListenerBuilder - .setListenerPayload(ByteString.copyFrom(SparkSerDeUtils - .serialize(StreamingListenerPacket(id, listener))))) + .setId(id)) removeCachedListener(id) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index bc778f02480b..f9e6e6864953 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -294,7 +294,7 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging { spark.sql("DROP TABLE IF EXISTS my_listener_table") } - // List listeners after adding a new listener, length should be 2. + // List listeners after adding a new listener, length should be 1. 
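The proto change above is the core of the wire-level update: StreamingQueryListenerCommand now carries an optional python_listener_payload (a PythonUDF holding a cloudpickled Python listener) plus a client-generated id, so removeListener only needs the id instead of re-serializing the listener. Below is a minimal Python sketch of how a Connect client is expected to fill in these fields, mirroring the pyspark.sql.connect.streaming.query changes later in this diff; the builder helper functions are illustrative and not part of the patch.

    import sys
    import uuid

    import pyspark.sql.connect.proto as pb2
    from pyspark.serializers import CloudPickleSerializer


    def build_add_listener_command(listener):
        # The listener is cloudpickled into a PythonUDF payload and tagged with a
        # client-generated id that the server caches the listener under.
        cmd = pb2.StreamingQueryManagerCommand()
        udf = pb2.PythonUDF()
        udf.command = CloudPickleSerializer().dumps(listener)
        udf.python_ver = "%d.%d" % sys.version_info[:2]
        cmd.add_listener.python_listener_payload.CopyFrom(udf)
        cmd.add_listener.id = str(uuid.uuid4())
        return cmd


    def build_remove_listener_command(listener_id):
        # Removal no longer re-serializes the listener; the cached id is enough.
        cmd = pb2.StreamingQueryManagerCommand()
        cmd.remove_listener.id = listener_id
        return cmd

On the server side, SparkConnectPlanner caches the listener under this id (sessionHolder.cacheListenerById), which is why ADD_LISTENER and REMOVE_LISTENER no longer need matching payloads.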
val listeners = spark.streams.listListeners() assert(listeners.length == 1) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto index 4c4233124d8a..49b25f099bf2 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -365,6 +365,8 @@ message StreamingQueryManagerCommand { message StreamingQueryListenerCommand { bytes listener_payload = 1; + optional PythonUDF python_listener_payload = 2; + string id = 3; } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 61e5d9de9142..f9a1e44516e8 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2831,7 +2831,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { StreamingForeachBatchHelper.scalaForeachBatchWrapper(scalaFn, sessionHolder) case StreamingForeachFunction.FunctionCase.FUNCTION_NOT_SET => - throw InvalidPlanInput("Unexpected") // Unreachable + throw InvalidPlanInput("Unexpected foreachBatch function") // Unreachable } writer.foreachBatch(foreachBatchFn) @@ -3074,23 +3074,27 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { respBuilder.setResetTerminated(true) case StreamingQueryManagerCommand.CommandCase.ADD_LISTENER => - val listenerPacket = Utils - .deserialize[StreamingListenerPacket]( - command.getAddListener.getListenerPayload.toByteArray, - Utils.getContextOrSparkClassLoader) - val listener: StreamingQueryListener = listenerPacket.listener - .asInstanceOf[StreamingQueryListener] - val id: String = listenerPacket.id + val listener = if (command.getAddListener.hasPythonListenerPayload) { + new PythonStreamingQueryListener( + transformPythonFunction(command.getAddListener.getPythonListenerPayload), + sessionHolder, + pythonExec) + } else { + val listenerPacket = Utils + .deserialize[StreamingListenerPacket]( + command.getAddListener.getListenerPayload.toByteArray, + Utils.getContextOrSparkClassLoader) + + listenerPacket.listener.asInstanceOf[StreamingQueryListener] + } + + val id = command.getAddListener.getId sessionHolder.cacheListenerById(id, listener) session.streams.addListener(listener) respBuilder.setAddListener(true) case StreamingQueryManagerCommand.CommandCase.REMOVE_LISTENER => - val listenerId = Utils - .deserialize[StreamingListenerPacket]( - command.getRemoveListener.getListenerPayload.toByteArray, - Utils.getContextOrSparkClassLoader) - .id + val listenerId = command.getRemoveListener.getId val listener: StreamingQueryListener = sessionHolder.getListenerOrThrow(listenerId) session.streams.removeListener(listener) sessionHolder.removeCachedListener(listenerId) diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala index 9770ac4cee5f..3b9ae483cf1b 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.connect.planner import java.util.UUID -import org.apache.spark.api.python.PythonRDD -import org.apache.spark.api.python.SimplePythonFunction -import org.apache.spark.api.python.StreamingPythonRunner +import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame import org.apache.spark.sql.connect.service.SessionHolder @@ -90,7 +88,10 @@ object StreamingForeachBatchHelper extends Logging { val port = SparkConnectService.localPort val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" val runner = StreamingPythonRunner(pythonFn, connectUrl) - val (dataOut, dataIn) = runner.init(sessionHolder.sessionId) + val (dataOut, dataIn) = + runner.init( + sessionHolder.sessionId, + "pyspark.sql.connect.streaming.worker.foreachBatch_worker") val foreachBatchRunnerFn: FnArgsWithId => Unit = (args: FnArgsWithId) => { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala new file mode 100644 index 000000000000..d915bc934960 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.planner + +import org.apache.spark.api.python.{PythonRDD, SimplePythonFunction, StreamingPythonRunner} +import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService} +import org.apache.spark.sql.streaming.StreamingQueryListener + +/** + * A helper class for handling StreamingQueryListener related functionality in Spark Connect. Each + * instance of this class starts a python process that runs the python handling logic. + * When a new event is received, it is serialized to json and passed to the python process.
+ */ +class PythonStreamingQueryListener( + listener: SimplePythonFunction, + sessionHolder: SessionHolder, + pythonExec: String) + extends StreamingQueryListener { + + val port = SparkConnectService.localPort + val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}" + val runner = StreamingPythonRunner(listener, connectUrl) + + val (dataOut, _) = + runner.init(sessionHolder.sessionId, "pyspark.sql.connect.streaming.worker.listener_worker") + + override def onQueryStarted(event: StreamingQueryListener.QueryStartedEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(0) + dataOut.flush() + } + + override def onQueryProgress(event: StreamingQueryListener.QueryProgressEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(1) + dataOut.flush() + } + + override def onQueryIdle(event: StreamingQueryListener.QueryIdleEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(2) + dataOut.flush() + } + + override def onQueryTerminated(event: StreamingQueryListener.QueryTerminatedEvent): Unit = { + PythonRDD.writeUTF(event.json, dataOut) + dataOut.writeInt(3) + dataOut.flush() + } + + // TODO(SPARK-44433)(SPARK-44516): Improve termination of Processes. + // Similar to foreachBatch when we need to exit the process when the query ends. + // In listener semantics, we need to exit the process when removeListener is called. +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 6039f8d232b4..d5d97b74d11f 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -110,9 +110,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } - /** Creates a Python worker with `pyspark.streaming_worker` module. */ - def createStreamingWorker(): (Socket, Option[Int]) = { - createSimpleWorker("pyspark.streaming_worker") + /** Creates a Python worker with streaming worker module. */ + def createStreamingWorker(streamingWorkerModule: String): (Socket, Option[Int]) = { + createSimpleWorker(streamingWorkerModule) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala index faf462a1990d..c02871ee1451 100644 --- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala @@ -48,9 +48,8 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str * Initializes the Python worker for streaming functions. Sets up Spark Connect session * to be used with the functions. 
*/ - def init(sessionId: String): (DataOutputStream, DataInputStream) = { + def init(sessionId: String, workerModule: String): (DataOutputStream, DataInputStream) = { logInfo(s"Initializing Python runner (session: $sessionId ,pythonExec: $pythonExec") - val env = SparkEnv.get val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") @@ -62,7 +61,7 @@ private[spark] class StreamingPythonRunner(func: PythonFunction, connectUrl: Str envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl) val pythonWorkerFactory = new PythonWorkerFactory(pythonExec, envVars.asScala.toMap) - val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker() + val (worker: Socket, _) = pythonWorkerFactory.createStreamingWorker(workerModule) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 79c3f8f26b1b..4005c317e628 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -863,6 +863,7 @@ def __hash__(self): "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_client", "pyspark.sql.tests.connect.streaming.test_parity_streaming", + "pyspark.sql.tests.connect.streaming.test_parity_listener", "pyspark.sql.tests.connect.streaming.test_parity_foreach", "pyspark.sql.tests.connect.streaming.test_parity_foreachBatch", "pyspark.sql.tests.connect.test_parity_pandas_grouped_map_with_state", diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 6f03d80e6696..90911e382bf1 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xf1\x01\n\nSqlCommand\x12\x10\n\x03sql\x18\x01 \x01(\tR\x03sql\x12\x37\n\x04\x61rgs\x18\x02 
\x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryR\x04\x61rgs\x12<\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07posArgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\x9b\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xad\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xa0\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"y\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 
\x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xb9\x05\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1aJ\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 
\x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\xf5\x07\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xf1\x01\n\nSqlCommand\x12\x10\n\x03sql\x18\x01 \x01(\tR\x03sql\x12\x37\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryR\x04\x61rgs\x12<\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x07posArgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\x9b\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 
\x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xad\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xa0\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"y\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 
\x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -115,27 +115,27 @@ _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 6330 _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 6386 _STREAMINGQUERYMANAGERCOMMAND._serialized_start = 6404 - _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7101 + _STREAMINGQUERYMANAGERCOMMAND._serialized_end = 7233 _STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_start = 6935 
_STREAMINGQUERYMANAGERCOMMAND_AWAITANYTERMINATIONCOMMAND._serialized_end = 7014 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7016 - _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7090 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7104 - _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8180 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7712 - _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7839 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7841 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 7956 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 7958 - _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8017 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8019 - _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8094 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8096 - _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8165 - _GETRESOURCESCOMMAND._serialized_start = 8182 - _GETRESOURCESCOMMAND._serialized_end = 8203 - _GETRESOURCESCOMMANDRESULT._serialized_start = 8206 - _GETRESOURCESCOMMANDRESULT._serialized_end = 8418 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8322 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8418 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_start = 7017 + _STREAMINGQUERYMANAGERCOMMAND_STREAMINGQUERYLISTENERCOMMAND._serialized_end = 7222 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_start = 7236 + _STREAMINGQUERYMANAGERCOMMANDRESULT._serialized_end = 8312 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_start = 7844 + _STREAMINGQUERYMANAGERCOMMANDRESULT_ACTIVERESULT._serialized_end = 7971 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_start = 7973 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYINSTANCE._serialized_end = 8088 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_start = 8090 + _STREAMINGQUERYMANAGERCOMMANDRESULT_AWAITANYTERMINATIONRESULT._serialized_end = 8149 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_start = 8151 + _STREAMINGQUERYMANAGERCOMMANDRESULT_STREAMINGQUERYLISTENERINSTANCE._serialized_end = 8226 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_start = 8228 + _STREAMINGQUERYMANAGERCOMMANDRESULT_LISTSTREAMINGQUERYLISTENERRESULT._serialized_end = 8297 + _GETRESOURCESCOMMAND._serialized_start = 8314 + _GETRESOURCESCOMMAND._serialized_end = 8335 + _GETRESOURCESCOMMANDRESULT._serialized_start = 8338 + _GETRESOURCESCOMMANDRESULT._serialized_end = 8550 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 8454 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 8550 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index be423ea036e9..f3dca7ab4bb7 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -1372,15 +1372,50 @@ class StreamingQueryManagerCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor 
LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int + PYTHON_LISTENER_PAYLOAD_FIELD_NUMBER: builtins.int + ID_FIELD_NUMBER: builtins.int listener_payload: builtins.bytes + @property + def python_listener_payload( + self, + ) -> pyspark.sql.connect.proto.expressions_pb2.PythonUDF: ... + id: builtins.str def __init__( self, *, listener_payload: builtins.bytes = ..., + python_listener_payload: pyspark.sql.connect.proto.expressions_pb2.PythonUDF + | None = ..., + id: builtins.str = ..., ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_python_listener_payload", + b"_python_listener_payload", + "python_listener_payload", + b"python_listener_payload", + ], + ) -> builtins.bool: ... def ClearField( - self, field_name: typing_extensions.Literal["listener_payload", b"listener_payload"] + self, + field_name: typing_extensions.Literal[ + "_python_listener_payload", + b"_python_listener_payload", + "id", + b"id", + "listener_payload", + b"listener_payload", + "python_listener_payload", + b"python_listener_payload", + ], ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal[ + "_python_listener_payload", b"_python_listener_payload" + ], + ) -> typing_extensions.Literal["python_listener_payload"] | None: ... ACTIVE_FIELD_NUMBER: builtins.int GET_QUERY_FIELD_NUMBER: builtins.int diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index e5aa881c9906..59e98e7bc30f 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -21,6 +21,9 @@ from pyspark.errors import StreamingQueryException, PySparkValueError import pyspark.sql.connect.proto as pb2 +from pyspark.serializers import CloudPickleSerializer +from pyspark.sql.connect import proto +from pyspark.sql.streaming import StreamingQueryListener from pyspark.sql.streaming.query import ( StreamingQuery as PySparkStreamingQuery, StreamingQueryManager as PySparkStreamingQueryManager, @@ -226,25 +229,27 @@ def resetTerminated(self) -> None: cmd = pb2.StreamingQueryManagerCommand() cmd.reset_terminated = True self._execute_streaming_query_manager_cmd(cmd) - return None resetTerminated.__doc__ = PySparkStreamingQueryManager.resetTerminated.__doc__ - def addListener(self, listener: Any) -> None: - # TODO(SPARK-42941): Change listener type to Connect StreamingQueryListener - # and implement below - raise NotImplementedError("addListener() is not implemented.") + def addListener(self, listener: StreamingQueryListener) -> None: + listener._init_listener_id() + cmd = pb2.StreamingQueryManagerCommand() + expr = proto.PythonUDF() + expr.command = CloudPickleSerializer().dumps(listener) + expr.python_ver = "%d.%d" % sys.version_info[:2] + cmd.add_listener.python_listener_payload.CopyFrom(expr) + cmd.add_listener.id = listener._id + self._execute_streaming_query_manager_cmd(cmd) - # TODO(SPARK-42941): uncomment below - # addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__ + addListener.__doc__ = PySparkStreamingQueryManager.addListener.__doc__ - def removeListener(self, listener: Any) -> None: - # TODO(SPARK-42941): Change listener type to Connect StreamingQueryListener - # and implement below - raise NotImplementedError("removeListener() is not implemented.") + def removeListener(self, listener: StreamingQueryListener) -> None: + cmd = pb2.StreamingQueryManagerCommand() + cmd.remove_listener.id = listener._id + self._execute_streaming_query_manager_cmd(cmd) - # TODO(SPARK-42941): uncomment 
below - # removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__ + removeListener.__doc__ = PySparkStreamingQueryManager.removeListener.__doc__ def _execute_streaming_query_manager_cmd( self, cmd: pb2.StreamingQueryManagerCommand diff --git a/python/pyspark/sql/connect/streaming/worker/__init__.py b/python/pyspark/sql/connect/streaming/worker/__init__.py new file mode 100644 index 000000000000..a5c980198919 --- /dev/null +++ b/python/pyspark/sql/connect/streaming/worker/__init__.py @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Spark Connect Streaming Server-side Worker""" diff --git a/python/pyspark/streaming_worker.py b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py similarity index 86% rename from python/pyspark/streaming_worker.py rename to python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py index a818880a9849..054788539f29 100644 --- a/python/pyspark/streaming_worker.py +++ b/python/pyspark/sql/connect/streaming/worker/foreachBatch_worker.py @@ -16,7 +16,8 @@ # """ -A worker for streaming foreachBatch and query listener in Spark Connect. +A worker for streaming foreachBatch in Spark Connect. +Usually this is ran on the driver side of the Spark Connect Server. """ import os @@ -29,20 +30,23 @@ ) from pyspark import worker from pyspark.sql import SparkSession +from typing import IO pickle_ser = CPickleSerializer() utf8_deserializer = UTF8Deserializer() -def main(infile, outfile): # type: ignore[no-untyped-def] - log_name = "Streaming ForeachBatch worker" +def main(infile: IO, outfile: IO) -> None: connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] session_id = utf8_deserializer.loads(infile) - print(f"{log_name} is starting with url {connect_url} and sessionId {session_id}.") + print( + "Streaming foreachBatch worker is starting with " + f"url {connect_url} and sessionId {session_id}." + ) spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() - spark_connect_session._client._session_id = session_id + spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] # TODO(SPARK-44460): Pass credentials. # TODO(SPARK-44461): Enable Process Isolation @@ -52,6 +56,8 @@ def main(infile, outfile): # type: ignore[no-untyped-def] outfile.flush() + log_name = "Streaming ForeachBatch worker" + def process(df_id, batch_id): # type: ignore[no-untyped-def] print(f"{log_name} Started batch {batch_id} with DF id {df_id}") batch_df = spark_connect_session._create_remote_dataframe(df_id) @@ -67,8 +73,6 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def] if __name__ == "__main__": - print("Starting streaming worker") - # Read information about how to connect back to the JVM from the environment. 
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py b/python/pyspark/sql/connect/streaming/worker/listener_worker.py new file mode 100644 index 000000000000..8eb310461b6f --- /dev/null +++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py @@ -0,0 +1,95 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +A worker for streaming query listener in Spark Connect. +Usually this is ran on the driver side of the Spark Connect Server. +""" +import os +import json + +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_int, + write_int, + UTF8Deserializer, + CPickleSerializer, +) +from pyspark import worker +from pyspark.sql import SparkSession +from typing import IO + +from pyspark.sql.streaming.listener import ( + QueryStartedEvent, + QueryProgressEvent, + QueryTerminatedEvent, + QueryIdleEvent, +) + +pickle_ser = CPickleSerializer() +utf8_deserializer = UTF8Deserializer() + + +def main(infile: IO, outfile: IO) -> None: + connect_url = os.environ["SPARK_CONNECT_LOCAL_URL"] + session_id = utf8_deserializer.loads(infile) + + print( + "Streaming query listener worker is starting with " + f"url {connect_url} and sessionId {session_id}." + ) + + spark_connect_session = SparkSession.builder.remote(connect_url).getOrCreate() + spark_connect_session._client._session_id = session_id # type: ignore[attr-defined] + + # TODO(SPARK-44460): Pass credentials. + # TODO(SPARK-44461): Enable Process Isolation + + listener = worker.read_command(pickle_ser, infile) + write_int(0, outfile) # Indicate successful initialization + + outfile.flush() + + listener._set_spark_session(spark_connect_session) + assert listener.spark == spark_connect_session + + def process(listener_event_str, listener_event_type): # type: ignore[no-untyped-def] + listener_event = json.loads(listener_event_str) + if listener_event_type == 0: + listener.onQueryStarted(QueryStartedEvent.fromJson(listener_event)) + elif listener_event_type == 1: + listener.onQueryProgress(QueryProgressEvent.fromJson(listener_event)) + elif listener_event_type == 2: + listener.onQueryIdle(QueryIdleEvent.fromJson(listener_event)) + elif listener_event_type == 3: + listener.onQueryTerminated(QueryTerminatedEvent.fromJson(listener_event)) + + while True: + event = utf8_deserializer.loads(infile) + event_type = read_int(infile) + process(event, int(event_type)) # TODO(SPARK-44463): Propagate error to the user. + outfile.flush() + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. 
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + write_int(os.getpid(), sock_file) + sock_file.flush() + main(sock_file, sock_file) diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py index 198af0c9cbeb..225ad6d45afb 100644 --- a/python/pyspark/sql/streaming/listener.py +++ b/python/pyspark/sql/streaming/listener.py @@ -62,6 +62,21 @@ class StreamingQueryListener(ABC): >>> spark.streams.addListener(MyListener()) """ + def _set_spark_session( + self, spark: "SparkSession" # type: ignore[name-defined] # noqa: F821 + ) -> None: + self._sparkSession = spark + + @property + def spark(self) -> Optional["SparkSession"]: # type: ignore[name-defined] # noqa: F821 + if hasattr(self, "_sparkSession"): + return self._sparkSession + else: + return None + + def _init_listener_id(self) -> None: + self._id = str(uuid.uuid4()) + @abstractmethod def onQueryStarted(self, event: "QueryStartedEvent") -> None: """ @@ -463,8 +478,8 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": timestamp=j["timestamp"], batchId=j["batchId"], batchDuration=j["batchDuration"], - durationMs=dict(j["durationMs"]), - eventTime=dict(j["eventTime"]), + durationMs=dict(j["durationMs"]) if "durationMs" in j else {}, + eventTime=dict(j["eventTime"]) if "eventTime" in j else {}, stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], sources=[SourceProgress.fromJson(s) for s in j["sources"]], sink=SinkProgress.fromJson(j["sink"]), @@ -474,7 +489,9 @@ def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress": observedMetrics={ k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows for k, row_dict in j["observedMetrics"].items() - }, + } + if "observedMetrics" in j + else {}, ) @property @@ -696,7 +713,7 @@ def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"], numShufflePartitions=j["numShufflePartitions"], numStateStoreInstances=j["numStateStoreInstances"], - customMetrics=dict(j["customMetrics"]), + customMetrics=dict(j["customMetrics"]) if "customMetrics" in j else {}, ) @property @@ -831,7 +848,7 @@ def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": numInputRows=j["numInputRows"], inputRowsPerSecond=j["inputRowsPerSecond"], processedRowsPerSecond=j["processedRowsPerSecond"], - metrics=dict(j["metrics"]), + metrics=dict(j["metrics"]) if "metrics" in j else {}, ) @property @@ -951,7 +968,7 @@ def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": jdict=j, description=j["description"], numOutputRows=j["numOutputRows"], - metrics=j["metrics"], + metrics=dict(j["metrics"]) if "metrics" in j else {}, ) @property diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py index 443e7dbee39b..db104e30755a 100644 --- a/python/pyspark/sql/streaming/query.py +++ b/python/pyspark/sql/streaming/query.py @@ -618,12 +618,24 @@ def addListener(self, listener: StreamingQueryListener) -> None: .. versionadded:: 3.4.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- listener : :class:`StreamingQueryListener` A :class:`StreamingQueryListener` to receive up-calls for life cycle events of :class:`~pyspark.sql.streaming.StreamingQuery`. + Notes + ----- + This function behaves differently in Spark Connect mode. 
+ In Connect, the provided functions don't have access to variables defined outside of them. + Also in Connect, you need to use `self.spark` to access the Spark session. + Using `spark` would throw an exception. + In short, if you want to use the Spark session inside the listener, + please use `self.spark` in Connect mode, and use `spark` otherwise. + Examples -------- >>> from pyspark.sql.streaming import StreamingQueryListener diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py new file mode 100644 index 000000000000..547462d4da6d --- /dev/null +++ b/python/pyspark/sql/tests/connect/streaming/test_parity_listener.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import time + +from pyspark.sql.tests.streaming.test_streaming_listener import StreamingListenerTestsMixin +from pyspark.sql.streaming.listener import StreamingQueryListener, QueryStartedEvent +from pyspark.sql.types import StructType, StructField, StringType +from pyspark.testing.connectutils import ReusedConnectTestCase + + +def get_start_event_schema(): + return StructType( + [ + StructField("id", StringType(), True), + StructField("runId", StringType(), True), + StructField("name", StringType(), True), + StructField("timestamp", StringType(), True), + ] + ) + + +class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + df = self.spark.createDataFrame( + data=[(str(event.id), str(event.runId), event.name, event.timestamp)], + schema=get_start_event_schema(), + ) + df.write.saveAsTable("listener_start_events") + + def onQueryProgress(self, event): + pass + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + pass + + +class StreamingListenerParityTests(StreamingListenerTestsMixin, ReusedConnectTestCase): + def test_listener_events(self): + test_listener = TestListener() + + try: + self.spark.streams.addListener(test_listener) + + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + q = df.writeStream.format("noop").queryName("test").start() + + self.assertTrue(q.isActive) + time.sleep(10) + q.stop() + + start_event = QueryStartedEvent.fromJson( + self.spark.read.table("listener_start_events").collect()[0].asDict() + ) + + self.check_start_event(start_event) + + finally: + self.spark.streams.removeListener(test_listener) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.streaming.test_parity_listener import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git 
a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 2bd6d2c66683..cbbdc2955e59 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -33,119 +33,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class StreamingListenerTests(ReusedSQLTestCase): - def test_number_of_public_methods(self): - msg = ( - "New field or method was detected in JVM side. If you added a new public " - "field or method, implement that in the corresponding Python class too." - "Otherwise, fix the number on the assert here." - ) - - def get_number_of_public_methods(clz): - return len( - self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName( - clz, True, False - ).getMethods() - ) - - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" - ), - 15, - msg, - ) - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" - ), - 12, - msg, - ) - self.assertEquals( - get_number_of_public_methods( - "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" - ), - 15, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"), - 38, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"), - 27, - msg, - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 21, msg - ) - self.assertEquals( - get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 19, msg - ) - - def test_listener_events(self): - start_event = None - progress_event = None - terminated_event = None - - class TestListener(StreamingQueryListener): - def onQueryStarted(self, event): - nonlocal start_event - start_event = event - - def onQueryProgress(self, event): - nonlocal progress_event - progress_event = event - - def onQueryIdle(self, event): - pass - - def onQueryTerminated(self, event): - nonlocal terminated_event - terminated_event = event - - test_listener = TestListener() - - try: - self.spark.streams.addListener(test_listener) - - df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - - # check successful stateful query - df_stateful = df.groupBy().count() # make query stateful - q = ( - df_stateful.writeStream.format("noop") - .queryName("test") - .outputMode("complete") - .start() - ) - self.assertTrue(q.isActive) - time.sleep(10) - q.stop() - - # Make sure all events are empty - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - - self.check_start_event(start_event) - self.check_progress_event(progress_event) - self.check_terminated_event(terminated_event) - - # Check query terminated with exception - from pyspark.sql.functions import col, udf - - bad_udf = udf(lambda x: 1 / 0) - q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() - time.sleep(5) - q.stop() - self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() - self.check_terminated_event(terminated_event, "ZeroDivisionError") - - finally: - self.spark.streams.removeListener(test_listener) - +class StreamingListenerTestsMixin: def check_start_event(self, event): """Check QueryStartedEvent""" self.assertTrue(isinstance(event, QueryStartedEvent)) @@ -304,6 +192,120 @@ def 
check_sink_progress(self, progress): self.assertTrue(isinstance(progress.numOutputRows, int)) self.assertTrue(isinstance(progress.metrics, dict)) + +class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase): + def test_number_of_public_methods(self): + msg = ( + "New field or method was detected in JVM side. If you added a new public " + "field or method, implement that in the corresponding Python class too." + "Otherwise, fix the number on the assert here." + ) + + def get_number_of_public_methods(clz): + return len( + self.spark.sparkContext._jvm.org.apache.spark.util.Utils.classForName( + clz, True, False + ).getMethods() + ) + + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" + ), + 15, + msg, + ) + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" + ), + 12, + msg, + ) + self.assertEquals( + get_number_of_public_methods( + "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" + ), + 15, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.StreamingQueryProgress"), + 38, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.StateOperatorProgress"), + 27, + msg, + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.SourceProgress"), 21, msg + ) + self.assertEquals( + get_number_of_public_methods("org.apache.spark.sql.streaming.SinkProgress"), 19, msg + ) + + def test_listener_events(self): + start_event = None + progress_event = None + terminated_event = None + + class TestListener(StreamingQueryListener): + def onQueryStarted(self, event): + nonlocal start_event + start_event = event + + def onQueryProgress(self, event): + nonlocal progress_event + progress_event = event + + def onQueryIdle(self, event): + pass + + def onQueryTerminated(self, event): + nonlocal terminated_event + terminated_event = event + + test_listener = TestListener() + + try: + self.spark.streams.addListener(test_listener) + + df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() + + # check successful stateful query + df_stateful = df.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) + self.assertTrue(q.isActive) + time.sleep(10) + q.stop() + + # Make sure all events are empty + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + + self.check_start_event(start_event) + self.check_progress_event(progress_event) + self.check_terminated_event(terminated_event) + + # Check query terminated with exception + from pyspark.sql.functions import col, udf + + bad_udf = udf(lambda x: 1 / 0) + q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() + time.sleep(5) + q.stop() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + self.check_terminated_event(terminated_event, "ZeroDivisionError") + + finally: + self.spark.streams.removeListener(test_listener) + def test_remove_listener(self): # SPARK-38804: Test StreamingQueryManager.removeListener class TestListener(StreamingQueryListener): diff --git a/python/setup.py b/python/setup.py index a0297d4f9dcf..11f82cfec378 100755 --- a/python/setup.py +++ b/python/setup.py @@ -250,6 +250,7 @@ def run(self): "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.streaming", + 
"pyspark.sql.connect.streaming.worker", "pyspark.bin", "pyspark.sbin", "pyspark.jars",