diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 2195a713ed20..2b648bf0f9a5 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -259,8 +259,10 @@ message StreamingQueryCommand {
     bool process_all_available = 6;
     // explain() API. Returns logical and physical plans.
     ExplainCommand explain = 7;
-
-    // TODO(SPARK-42960) Add more commands: await_termination(), exception() etc.
+    // exception() API. Returns the exception in the query if any.
+    bool exception = 8;
+    // awaitTermination() API. Waits for the termination of the query.
+    AwaitTerminationCommand await_termination = 9;
   }
 
   message ExplainCommand {
@@ -268,6 +270,10 @@ message StreamingQueryCommand {
     // We can not do this right now since it base.proto imports this file.
     bool extended = 1;
   }
+
+  message AwaitTerminationCommand {
+    optional int64 timeout_ms = 2;
+  }
 }
 
 // Response for commands on a streaming query.
@@ -279,6 +285,8 @@ message StreamingQueryCommandResult {
     StatusResult status = 2;
     RecentProgressResult recent_progress = 3;
     ExplainResult explain = 4;
+    ExceptionResult exception = 5;
+    AwaitTerminationResult await_termination = 6;
   }
 
   message StatusResult {
@@ -298,6 +306,15 @@ message StreamingQueryCommandResult {
     // Logical and physical plans as string
     string result = 1;
   }
+
+  message ExceptionResult {
+    // Exception message as string
+    optional string exception_message = 1;
+  }
+
+  message AwaitTerminationResult {
+    bool terminated = 1;
+  }
 }
 
 // Command to get the output of 'SparkContext.resources'
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 5f39fcd17f78..db6418c5ad10 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -52,7 +52,7 @@ import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket}
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
-import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
+import org.apache.spark.sql.connect.service.{SparkConnectService, SparkConnectStreamHandler}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -2255,6 +2255,23 @@ class SparkConnectPlanner(val session: SparkSession) {
           .build()
         respBuilder.setExplain(explain)
 
+      case StreamingQueryCommand.CommandCase.EXCEPTION =>
+        val result = query.exception
+        result.foreach(e =>
+          respBuilder.getExceptionBuilder
+            .setExceptionMessage(SparkConnectService.extractErrorMessage(e)))
+
+      case StreamingQueryCommand.CommandCase.AWAIT_TERMINATION =>
+        if (command.getAwaitTermination.hasTimeoutMs) {
+          val terminated = query.awaitTermination(command.getAwaitTermination.getTimeoutMs)
+          respBuilder.getAwaitTerminationBuilder
+            .setTerminated(terminated)
+        } else {
+
query.awaitTermination() + respBuilder.getAwaitTerminationBuilder + .setTerminated(true) + } + case StreamingQueryCommand.CommandCase.COMMAND_NOT_SET => throw new IllegalArgumentException("Missing command in StreamingQueryCommand") } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 86c36bba7a08..86590569aaae 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -74,7 +74,6 @@ class SparkConnectService(debug: Boolean) } private def buildStatusFromThrowable(st: Throwable): RPCStatus = { - val message = StringUtils.abbreviate(st.getMessage, 2048) RPCStatus .newBuilder() .setCode(RPCCode.INTERNAL_VALUE) @@ -86,7 +85,7 @@ class SparkConnectService(debug: Boolean) .setDomain("org.apache.spark") .putMetadata("classes", compact(render(allClasses(st.getClass).map(_.getName)))) .build())) - .setMessage(if (message != null) message else "") + .setMessage(SparkConnectService.extractErrorMessage(st)) .build() } @@ -295,4 +294,13 @@ object SparkConnectService { } } } + + def extractErrorMessage(st: Throwable): String = { + val message = StringUtils.abbreviate(st.getMessage, 2048) + if (message != null) { + message + } else { + "" + } + } } diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 1cd0fad41eee..27de95a7aaa3 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -36,7 +36,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\x06\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xb3\x01\n\nSqlCommand\x12\x10\n\x03sql\x18\x01 \x01(\tR\x03sql\x12\x37\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryR\x04\x61rgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\x9b\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xad\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\x82\x05\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 
\x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"y\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\x9d\x03\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtendedB\t\n\x07\x63ommand"\xa5\x05\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06resultB\r\n\x0bresult_type"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\x06\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 
\x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xb3\x01\n\nSqlCommand\x12\x10\n\x03sql\x18\x01 \x01(\tR\x03sql\x12\x37\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryR\x04\x61rgs\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\x9b\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xad\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 
\x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\x82\x05\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"y\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\x88\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 
\x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1aY\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x42\x14\n\x12_exception_message\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -65,6 +65,9 @@ _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND = _STREAMINGQUERYCOMMAND.nested_types_by_name[ "ExplainCommand" ] +_STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND = _STREAMINGQUERYCOMMAND.nested_types_by_name[ + "AwaitTerminationCommand" +] _STREAMINGQUERYCOMMANDRESULT = DESCRIPTOR.message_types_by_name["StreamingQueryCommandResult"] _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT = _STREAMINGQUERYCOMMANDRESULT.nested_types_by_name[ "StatusResult" @@ -75,6 +78,12 @@ _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT = _STREAMINGQUERYCOMMANDRESULT.nested_types_by_name[ "ExplainResult" ] +_STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT = _STREAMINGQUERYCOMMANDRESULT.nested_types_by_name[ + "ExceptionResult" +] +_STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT = ( + _STREAMINGQUERYCOMMANDRESULT.nested_types_by_name["AwaitTerminationResult"] +) _GETRESOURCESCOMMAND = DESCRIPTOR.message_types_by_name["GetResourcesCommand"] _GETRESOURCESCOMMANDRESULT = DESCRIPTOR.message_types_by_name["GetResourcesCommandResult"] _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY = _GETRESOURCESCOMMANDRESULT.nested_types_by_name[ @@ -256,6 +265,15 @@ # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommand.ExplainCommand) }, ), + "AwaitTerminationCommand": _reflection.GeneratedProtocolMessageType( + "AwaitTerminationCommand", + (_message.Message,), + { + "DESCRIPTOR": _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND, + "__module__": "spark.connect.commands_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommand.AwaitTerminationCommand) + }, + ), "DESCRIPTOR": _STREAMINGQUERYCOMMAND, "__module__": "spark.connect.commands_pb2" # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommand) @@ -263,6 +281,7 @@ ) _sym_db.RegisterMessage(StreamingQueryCommand) _sym_db.RegisterMessage(StreamingQueryCommand.ExplainCommand) +_sym_db.RegisterMessage(StreamingQueryCommand.AwaitTerminationCommand) 
StreamingQueryCommandResult = _reflection.GeneratedProtocolMessageType( "StreamingQueryCommandResult", @@ -295,6 +314,24 @@ # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommandResult.ExplainResult) }, ), + "ExceptionResult": _reflection.GeneratedProtocolMessageType( + "ExceptionResult", + (_message.Message,), + { + "DESCRIPTOR": _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT, + "__module__": "spark.connect.commands_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommandResult.ExceptionResult) + }, + ), + "AwaitTerminationResult": _reflection.GeneratedProtocolMessageType( + "AwaitTerminationResult", + (_message.Message,), + { + "DESCRIPTOR": _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT, + "__module__": "spark.connect.commands_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommandResult.AwaitTerminationResult) + }, + ), "DESCRIPTOR": _STREAMINGQUERYCOMMANDRESULT, "__module__": "spark.connect.commands_pb2" # @@protoc_insertion_point(class_scope:spark.connect.StreamingQueryCommandResult) @@ -304,6 +341,8 @@ _sym_db.RegisterMessage(StreamingQueryCommandResult.StatusResult) _sym_db.RegisterMessage(StreamingQueryCommandResult.RecentProgressResult) _sym_db.RegisterMessage(StreamingQueryCommandResult.ExplainResult) +_sym_db.RegisterMessage(StreamingQueryCommandResult.ExceptionResult) +_sym_db.RegisterMessage(StreamingQueryCommandResult.AwaitTerminationResult) GetResourcesCommand = _reflection.GeneratedProtocolMessageType( "GetResourcesCommand", @@ -390,21 +429,27 @@ _STREAMINGQUERYINSTANCEID._serialized_start = 3926 _STREAMINGQUERYINSTANCEID._serialized_end = 3991 _STREAMINGQUERYCOMMAND._serialized_start = 3994 - _STREAMINGQUERYCOMMAND._serialized_end = 4407 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 4352 - _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 4396 - _STREAMINGQUERYCOMMANDRESULT._serialized_start = 4410 - _STREAMINGQUERYCOMMANDRESULT._serialized_end = 5087 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 4787 - _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 4957 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 4959 - _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5031 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5033 - _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 5072 - _GETRESOURCESCOMMAND._serialized_start = 5089 - _GETRESOURCESCOMMAND._serialized_end = 5110 - _GETRESOURCESCOMMANDRESULT._serialized_start = 5113 - _GETRESOURCESCOMMANDRESULT._serialized_end = 5325 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 5229 - _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 5325 + _STREAMINGQUERYCOMMAND._serialized_end = 4626 + _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_start = 4493 + _STREAMINGQUERYCOMMAND_EXPLAINCOMMAND._serialized_end = 4537 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_start = 4539 + _STREAMINGQUERYCOMMAND_AWAITTERMINATIONCOMMAND._serialized_end = 4615 + _STREAMINGQUERYCOMMANDRESULT._serialized_start = 4629 + _STREAMINGQUERYCOMMANDRESULT._serialized_end = 5661 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_start = 5212 + _STREAMINGQUERYCOMMANDRESULT_STATUSRESULT._serialized_end = 5382 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_start = 5384 + _STREAMINGQUERYCOMMANDRESULT_RECENTPROGRESSRESULT._serialized_end = 5456 + _STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_start = 5458 + 
_STREAMINGQUERYCOMMANDRESULT_EXPLAINRESULT._serialized_end = 5497 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_start = 5499 + _STREAMINGQUERYCOMMANDRESULT_EXCEPTIONRESULT._serialized_end = 5588 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_start = 5590 + _STREAMINGQUERYCOMMANDRESULT_AWAITTERMINATIONRESULT._serialized_end = 5646 + _GETRESOURCESCOMMAND._serialized_start = 5663 + _GETRESOURCESCOMMAND._serialized_end = 5684 + _GETRESOURCESCOMMANDRESULT._serialized_start = 5687 + _GETRESOURCESCOMMANDRESULT._serialized_end = 5899 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_start = 5803 + _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_end = 5899 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index 419a2bd840e3..972fe7503a1a 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -876,6 +876,32 @@ class StreamingQueryCommand(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["extended", b"extended"] ) -> None: ... + class AwaitTerminationCommand(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TIMEOUT_MS_FIELD_NUMBER: builtins.int + timeout_ms: builtins.int + def __init__( + self, + *, + timeout_ms: builtins.int | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_timeout_ms", b"_timeout_ms", "timeout_ms", b"timeout_ms" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_timeout_ms", b"_timeout_ms", "timeout_ms", b"timeout_ms" + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_timeout_ms", b"_timeout_ms"] + ) -> typing_extensions.Literal["timeout_ms"] | None: ... + QUERY_ID_FIELD_NUMBER: builtins.int STATUS_FIELD_NUMBER: builtins.int LAST_PROGRESS_FIELD_NUMBER: builtins.int @@ -883,6 +909,8 @@ class StreamingQueryCommand(google.protobuf.message.Message): STOP_FIELD_NUMBER: builtins.int PROCESS_ALL_AVAILABLE_FIELD_NUMBER: builtins.int EXPLAIN_FIELD_NUMBER: builtins.int + EXCEPTION_FIELD_NUMBER: builtins.int + AWAIT_TERMINATION_FIELD_NUMBER: builtins.int @property def query_id(self) -> global___StreamingQueryInstanceId: """(Required) Query instance. See `StreamingQueryInstanceId`.""" @@ -899,6 +927,11 @@ class StreamingQueryCommand(google.protobuf.message.Message): @property def explain(self) -> global___StreamingQueryCommand.ExplainCommand: """explain() API. Returns logical and physical plans.""" + exception: builtins.bool + """exception() API. Returns the exception in the query if any.""" + @property + def await_termination(self) -> global___StreamingQueryCommand.AwaitTerminationCommand: + """awaitTermination() API. Waits for the termination of the query.""" def __init__( self, *, @@ -909,12 +942,18 @@ class StreamingQueryCommand(google.protobuf.message.Message): stop: builtins.bool = ..., process_all_available: builtins.bool = ..., explain: global___StreamingQueryCommand.ExplainCommand | None = ..., + exception: builtins.bool = ..., + await_termination: global___StreamingQueryCommand.AwaitTerminationCommand | None = ..., ) -> None: ... 
def HasField( self, field_name: typing_extensions.Literal[ + "await_termination", + b"await_termination", "command", b"command", + "exception", + b"exception", "explain", b"explain", "last_progress", @@ -934,8 +973,12 @@ class StreamingQueryCommand(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ + "await_termination", + b"await_termination", "command", b"command", + "exception", + b"exception", "explain", b"explain", "last_progress", @@ -955,7 +998,14 @@ class StreamingQueryCommand(google.protobuf.message.Message): def WhichOneof( self, oneof_group: typing_extensions.Literal["command", b"command"] ) -> typing_extensions.Literal[ - "status", "last_progress", "recent_progress", "stop", "process_all_available", "explain" + "status", + "last_progress", + "recent_progress", + "stop", + "process_all_available", + "explain", + "exception", + "await_termination", ] | None: ... global___StreamingQueryCommand = StreamingQueryCommand @@ -1033,10 +1083,60 @@ class StreamingQueryCommandResult(google.protobuf.message.Message): self, field_name: typing_extensions.Literal["result", b"result"] ) -> None: ... + class ExceptionResult(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + EXCEPTION_MESSAGE_FIELD_NUMBER: builtins.int + exception_message: builtins.str + """Exception message as string""" + def __init__( + self, + *, + exception_message: builtins.str | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_exception_message", + b"_exception_message", + "exception_message", + b"exception_message", + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_exception_message", + b"_exception_message", + "exception_message", + b"exception_message", + ], + ) -> None: ... + def WhichOneof( + self, + oneof_group: typing_extensions.Literal["_exception_message", b"_exception_message"], + ) -> typing_extensions.Literal["exception_message"] | None: ... + + class AwaitTerminationResult(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TERMINATED_FIELD_NUMBER: builtins.int + terminated: builtins.bool + def __init__( + self, + *, + terminated: builtins.bool = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["terminated", b"terminated"] + ) -> None: ... + QUERY_ID_FIELD_NUMBER: builtins.int STATUS_FIELD_NUMBER: builtins.int RECENT_PROGRESS_FIELD_NUMBER: builtins.int EXPLAIN_FIELD_NUMBER: builtins.int + EXCEPTION_FIELD_NUMBER: builtins.int + AWAIT_TERMINATION_FIELD_NUMBER: builtins.int @property def query_id(self) -> global___StreamingQueryInstanceId: """(Required) Query instance id. See `StreamingQueryInstanceId`.""" @@ -1046,6 +1146,10 @@ class StreamingQueryCommandResult(google.protobuf.message.Message): def recent_progress(self) -> global___StreamingQueryCommandResult.RecentProgressResult: ... @property def explain(self) -> global___StreamingQueryCommandResult.ExplainResult: ... + @property + def exception(self) -> global___StreamingQueryCommandResult.ExceptionResult: ... + @property + def await_termination(self) -> global___StreamingQueryCommandResult.AwaitTerminationResult: ... 
def __init__( self, *, @@ -1053,10 +1157,16 @@ class StreamingQueryCommandResult(google.protobuf.message.Message): status: global___StreamingQueryCommandResult.StatusResult | None = ..., recent_progress: global___StreamingQueryCommandResult.RecentProgressResult | None = ..., explain: global___StreamingQueryCommandResult.ExplainResult | None = ..., + exception: global___StreamingQueryCommandResult.ExceptionResult | None = ..., + await_termination: global___StreamingQueryCommandResult.AwaitTerminationResult | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ + "await_termination", + b"await_termination", + "exception", + b"exception", "explain", b"explain", "query_id", @@ -1072,6 +1182,10 @@ class StreamingQueryCommandResult(google.protobuf.message.Message): def ClearField( self, field_name: typing_extensions.Literal[ + "await_termination", + b"await_termination", + "exception", + b"exception", "explain", b"explain", "query_id", @@ -1086,7 +1200,9 @@ class StreamingQueryCommandResult(google.protobuf.message.Message): ) -> None: ... def WhichOneof( self, oneof_group: typing_extensions.Literal["result_type", b"result_type"] - ) -> typing_extensions.Literal["status", "recent_progress", "explain"] | None: ... + ) -> typing_extensions.Literal[ + "status", "recent_progress", "explain", "exception", "await_termination" + ] | None: ... global___StreamingQueryCommandResult = StreamingQueryCommandResult diff --git a/python/pyspark/sql/connect/streaming/query.py b/python/pyspark/sql/connect/streaming/query.py index aebab9fc69fd..a2b2e81357ef 100644 --- a/python/pyspark/sql/connect/streaming/query.py +++ b/python/pyspark/sql/connect/streaming/query.py @@ -24,6 +24,9 @@ from pyspark.sql.streaming.query import ( StreamingQuery as PySparkStreamingQuery, ) +from pyspark.errors.exceptions.connect import ( + StreamingQueryException as CapturedStreamingQueryException, +) __all__ = [ "StreamingQuery", # TODO(SPARK-43032): "StreamingQueryManager" @@ -66,11 +69,21 @@ def isActive(self) -> bool: isActive.__doc__ = PySparkStreamingQuery.isActive.__doc__ - # TODO (SPARK-42960): Implement and uncomment the doc def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]: - raise NotImplementedError() + cmd = pb2.StreamingQueryCommand() + if timeout is not None: + if not isinstance(timeout, (int, float)) or timeout <= 0: + raise ValueError("timeout must be a positive integer or float. 
Got %s" % timeout)
+            cmd.await_termination.timeout_ms = int(timeout * 1000)
+            terminated = self._execute_streaming_query_cmd(cmd).await_termination.terminated
+            return terminated
+        else:
+            await_termination_cmd = pb2.StreamingQueryCommand.AwaitTerminationCommand()
+            cmd.await_termination.CopyFrom(await_termination_cmd)
+            self._execute_streaming_query_cmd(cmd)
+            return None
 
-    # awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
+    awaitTermination.__doc__ = PySparkStreamingQuery.awaitTermination.__doc__
 
     @property
     def status(self) -> Dict[str, Any]:
@@ -127,9 +140,14 @@ def explain(self, extended: bool = False) -> None:
 
     explain.__doc__ = PySparkStreamingQuery.explain.__doc__
 
-    # TODO (SPARK-42960): Implement and uncomment the doc
     def exception(self) -> Optional[StreamingQueryException]:
-        raise NotImplementedError()
+        cmd = pb2.StreamingQueryCommand()
+        cmd.exception = True
+        exception = self._execute_streaming_query_cmd(cmd).exception
+        if exception.HasField("exception_message"):
+            return CapturedStreamingQueryException(exception.exception_message)
+        else:
+            return None
 
     exception.__doc__ = PySparkStreamingQuery.exception.__doc__
 
diff --git a/python/pyspark/sql/streaming/query.py b/python/pyspark/sql/streaming/query.py
index b902f0514fce..aada8265dd2d 100644
--- a/python/pyspark/sql/streaming/query.py
+++ b/python/pyspark/sql/streaming/query.py
@@ -196,7 +196,7 @@ def awaitTermination(self, timeout: Optional[int] = None) -> Optional[bool]:
         >>> sq.stop()
         """
         if timeout is not None:
-            if not isinstance(timeout, (int, float)) or timeout < 0:
+            if not isinstance(timeout, (int, float)) or timeout <= 0:
                 raise ValueError("timeout must be a positive integer or float. Got %s" % timeout)
             return self._jsq.awaitTermination(int(timeout * 1000))
         else:
diff --git a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py
index fc4b251c5c61..9419194a6e35 100644
--- a/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py
+++ b/python/pyspark/sql/tests/connect/streaming/test_parity_streaming.py
@@ -22,11 +22,14 @@
 
 
 class StreamingParityTests(StreamingTestsMixin, ReusedConnectTestCase):
-    @unittest.skip("Will be supported with SPARK-42960.")
+    @unittest.skip("Query manager API will be supported later with SPARK-43032.")
     def test_stream_await_termination(self):
         super().test_stream_await_termination()
 
-    @unittest.skip("Will be supported with SPARK-42960.")
+    @unittest.skip(
+        "Query immediately quits after throw, "
+        + "allowing access to supported queries will be added in SPARK-42962."
+    )
     def test_stream_exception(self):
         super().test_stream_exception()
 
diff --git a/python/pyspark/sql/tests/streaming/test_streaming.py b/python/pyspark/sql/tests/streaming/test_streaming.py
index 838d413a0cc3..52fa19a86420 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming.py
@@ -24,6 +24,7 @@
 from pyspark.sql.functions import lit
 from pyspark.sql.types import StructType, StructField, IntegerType, StringType
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.errors.exceptions.connect import SparkConnectException
 
 
 class StreamingTestsMixin:
@@ -39,9 +40,8 @@ def test_streaming_query_functions_basic(self):
         self.assertTrue(isinstance(query.id, str))
         self.assertTrue(isinstance(query.runId, str))
         self.assertTrue(query.isActive)
-        # TODO: Will be uncommented with [SPARK-42960]
-        # self.assertEqual(query.exception(), None)
-        # self.assertFalse(query.awaitTermination(1))
+        self.assertEqual(query.exception(), None)
+        self.assertFalse(query.awaitTermination(1))
         query.processAllAvailable()
         recentProgress = query.recentProgress
         lastProgress = query.lastProgress
@@ -253,8 +253,13 @@ def test_stream_await_termination(self):
             duration = time.time() - now
             self.assertTrue(duration >= 2)
             self.assertFalse(res)
-        finally:
+            q.processAllAvailable()
+            q.stop()
+            # Sanity check when no parameter is set
+            q.awaitTermination()
+            self.assertFalse(q.isActive)
+        finally:
             q.stop()
             shutil.rmtree(tmpPath)
 
@@ -285,11 +290,24 @@ def test_stream_exception(self):
             # This is expected
             self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
         finally:
+            exception = sq.exception()
             sq.stop()
-            self.assertIsInstance(sq.exception(), StreamingQueryException)
-            self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError")
+            self.assertIsInstance(exception, StreamingQueryException)
+            self._assert_exception_tree_contains_msg(exception, "ZeroDivisionError")
 
     def _assert_exception_tree_contains_msg(self, exception, msg):
+        if isinstance(exception, SparkConnectException):
+            self._assert_exception_tree_contains_msg_connect(exception, msg)
+        else:
+            self._assert_exception_tree_contains_msg_default(exception, msg)
+
+    def _assert_exception_tree_contains_msg_connect(self, exception, msg):
+        self.assertTrue(
+            msg in exception.message,
+            "Exception tree doesn't contain the expected message: %s" % msg,
+        )
+
+    def _assert_exception_tree_contains_msg_default(self, exception, msg):
         e = exception
         contains = msg in e.desc
         while e.cause is not None and not contains:
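Editor's note: a minimal usage sketch of the two new commands as seen from the Python Spark Connect client, mirroring the assertions in the tests above. The remote address, source/sink formats, and query name are illustrative assumptions, not part of this diff.

# Minimal sketch, assuming a Spark Connect server at "sc://localhost" and the
# changes in this patch; the "rate" source and "memory" sink are just examples.
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

query = (
    spark.readStream.format("rate")
    .load()
    .writeStream.format("memory")
    .queryName("await_termination_demo")  # hypothetical query name
    .start()
)

# With a timeout (in seconds), the client sends
# StreamingQueryCommand.await_termination.timeout_ms = int(timeout * 1000)
# and returns AwaitTerminationResult.terminated; False here, since the query
# is still running after one second.
assert not query.awaitTermination(1)

# exception() sends StreamingQueryCommand.exception = True; the server fills
# ExceptionResult.exception_message only if the query has failed, so this
# returns None for a healthy query.
assert query.exception() is None

query.stop()

# Without a timeout, an empty AwaitTerminationCommand is sent, the server
# blocks until the query terminates, and the Connect client returns None.
query.awaitTermination()

On the classic (non-Connect) path, awaitTermination keeps its JVM-backed behaviour; the only change there is rejecting timeout <= 0 instead of timeout < 0.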