diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 33482664a7b02..198af0c9cbeb5 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #
 import uuid
-from typing import Optional, Dict, List
+import json
+from typing import Any, Dict, List, Optional
 from abc import ABC, abstractmethod
 
 from py4j.java_gateway import JavaObject
@@ -129,16 +130,16 @@ def __init__(self, pylistener: StreamingQueryListener) -> None:
         self.pylistener = pylistener
 
     def onQueryStarted(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryStarted(QueryStartedEvent(jevent))
+        self.pylistener.onQueryStarted(QueryStartedEvent.fromJObject(jevent))
 
     def onQueryProgress(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryProgress(QueryProgressEvent(jevent))
+        self.pylistener.onQueryProgress(QueryProgressEvent.fromJObject(jevent))
 
     def onQueryIdle(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryIdle(QueryIdleEvent(jevent))
+        self.pylistener.onQueryIdle(QueryIdleEvent.fromJObject(jevent))
 
     def onQueryTerminated(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryTerminated(QueryTerminatedEvent(jevent))
+        self.pylistener.onQueryTerminated(QueryTerminatedEvent.fromJObject(jevent))
 
     class Java:
         implements = ["org.apache.spark.sql.streaming.PythonStreamingQueryListener"]
@@ -155,11 +156,31 @@ class QueryStartedEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._id: uuid.UUID = uuid.UUID(jevent.id().toString())
-        self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString())
-        self._name: Optional[str] = jevent.name()
-        self._timestamp: str = jevent.timestamp()
+    def __init__(
+        self, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], timestamp: str
+    ) -> None:
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._name: Optional[str] = name
+        self._timestamp: str = timestamp
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryStartedEvent":
+        return cls(
+            id=uuid.UUID(jevent.id().toString()),
+            runId=uuid.UUID(jevent.runId().toString()),
+            name=jevent.name(),
+            timestamp=jevent.timestamp(),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent":
+        return cls(
+            id=uuid.UUID(j["id"]),
+            runId=uuid.UUID(j["runId"]),
+            name=j["name"],
+            timestamp=j["timestamp"],
+        )
 
     @property
     def id(self) -> uuid.UUID:
@@ -203,8 +224,16 @@ class QueryProgressEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._progress: StreamingQueryProgress = StreamingQueryProgress(jevent.progress())
+    def __init__(self, progress: "StreamingQueryProgress") -> None:
+        self._progress: StreamingQueryProgress = progress
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryProgressEvent":
+        return cls(progress=StreamingQueryProgress.fromJObject(jevent.progress()))
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryProgressEvent":
+        return cls(progress=StreamingQueryProgress.fromJson(j["progress"]))
 
     @property
     def progress(self) -> "StreamingQueryProgress":
@@ -225,10 +254,22 @@ class QueryIdleEvent:
     This API is evolving.
""" - def __init__(self, jevent: JavaObject) -> None: - self._id: uuid.UUID = uuid.UUID(jevent.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString()) - self._timestamp: str = jevent.timestamp() + def __init__(self, id: uuid.UUID, runId: uuid.UUID, timestamp: str) -> None: + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._timestamp: str = timestamp + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent": + return cls( + id=uuid.UUID(jevent.id().toString()), + runId=uuid.UUID(jevent.runId().toString()), + timestamp=jevent.timestamp(), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent": + return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"]) @property def id(self) -> uuid.UUID: @@ -265,14 +306,36 @@ class QueryTerminatedEvent: This API is evolving. """ - def __init__(self, jevent: JavaObject) -> None: - self._id: uuid.UUID = uuid.UUID(jevent.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString()) + def __init__( + self, + id: uuid.UUID, + runId: uuid.UUID, + exception: Optional[str], + errorClassOnException: Optional[str], + ) -> None: + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._exception: Optional[str] = exception + self._errorClassOnException: Optional[str] = errorClassOnException + + @classmethod + def fromJObject(cls, jevent: JavaObject) -> "QueryTerminatedEvent": jexception = jevent.exception() - self._exception: Optional[str] = jexception.get() if jexception.isDefined() else None jerrorclass = jevent.errorClassOnException() - self._errorClassOnException: Optional[str] = ( - jerrorclass.get() if jerrorclass.isDefined() else None + return cls( + id=uuid.UUID(jevent.id().toString()), + runId=uuid.UUID(jevent.runId().toString()), + exception=jexception.get() if jexception.isDefined() else None, + errorClassOnException=jerrorclass.get() if jerrorclass.isDefined() else None, + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent": + return cls( + id=uuid.UUID(j["id"]), + runId=uuid.UUID(j["runId"]), + exception=j["exception"], + errorClassOnException=j["errorClassOnException"], ) @property @@ -322,32 +385,97 @@ class StreamingQueryProgress: This API is evolving. 
""" - def __init__(self, jprogress: JavaObject) -> None: + def __init__( + self, + id: uuid.UUID, + runId: uuid.UUID, + name: Optional[str], + timestamp: str, + batchId: int, + batchDuration: int, + durationMs: Dict[str, int], + eventTime: Dict[str, str], + stateOperators: List["StateOperatorProgress"], + sources: List["SourceProgress"], + sink: "SinkProgress", + numInputRows: int, + inputRowsPerSecond: float, + processedRowsPerSecond: float, + observedMetrics: Dict[str, Row], + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, + ): + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict + self._id: uuid.UUID = id + self._runId: uuid.UUID = runId + self._name: Optional[str] = name + self._timestamp: str = timestamp + self._batchId: int = batchId + self._batchDuration: int = batchDuration + self._durationMs: Dict[str, int] = durationMs + self._eventTime: Dict[str, str] = eventTime + self._stateOperators: List[StateOperatorProgress] = stateOperators + self._sources: List[SourceProgress] = sources + self._sink: SinkProgress = sink + self._numInputRows: int = numInputRows + self._inputRowsPerSecond: float = inputRowsPerSecond + self._processedRowsPerSecond: float = processedRowsPerSecond + self._observedMetrics: Dict[str, Row] = observedMetrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress": from pyspark import SparkContext - self._jprogress: JavaObject = jprogress - self._id: uuid.UUID = uuid.UUID(jprogress.id().toString()) - self._runId: uuid.UUID = uuid.UUID(jprogress.runId().toString()) - self._name: Optional[str] = jprogress.name() - self._timestamp: str = jprogress.timestamp() - self._batchId: int = jprogress.batchId() - self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond() - self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond() - self._batchDuration: int = jprogress.batchDuration() - self._durationMs: Dict[str, int] = dict(jprogress.durationMs()) - self._eventTime: Dict[str, str] = dict(jprogress.eventTime()) - self._stateOperators: List[StateOperatorProgress] = [ - StateOperatorProgress(js) for js in jprogress.stateOperators() - ] - self._sources: List[SourceProgress] = [SourceProgress(js) for js in jprogress.sources()] - self._sink: SinkProgress = SinkProgress(jprogress.sink()) - - self._observedMetrics: Dict[str, Row] = { - k: cloudpickle.loads( - SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type: ignore[union-attr] - ) - for k, jr in dict(jprogress.observedMetrics()).items() - } + return cls( + jprogress=jprogress, + id=uuid.UUID(jprogress.id().toString()), + runId=uuid.UUID(jprogress.runId().toString()), + name=jprogress.name(), + timestamp=jprogress.timestamp(), + batchId=jprogress.batchId(), + batchDuration=jprogress.batchDuration(), + durationMs=dict(jprogress.durationMs()), + eventTime=dict(jprogress.eventTime()), + stateOperators=[ + StateOperatorProgress.fromJObject(js) for js in jprogress.stateOperators() + ], + sources=[SourceProgress.fromJObject(js) for js in jprogress.sources()], + sink=SinkProgress.fromJObject(jprogress.sink()), + numInputRows=jprogress.numInputRows(), + inputRowsPerSecond=jprogress.inputRowsPerSecond(), + processedRowsPerSecond=jprogress.processedRowsPerSecond(), + observedMetrics={ + k: cloudpickle.loads( + SparkContext._jvm.PythonSQLUtils.toPyRow(jr) # type: ignore[union-attr] + ) + for k, jr in dict(jprogress.observedMetrics()).items() + }, + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) 
-> "StreamingQueryProgress": + return cls( + jdict=j, + id=uuid.UUID(j["id"]), + runId=uuid.UUID(j["runId"]), + name=j["name"], + timestamp=j["timestamp"], + batchId=j["batchId"], + batchDuration=j["batchDuration"], + durationMs=dict(j["durationMs"]), + eventTime=dict(j["eventTime"]), + stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]], + sources=[SourceProgress.fromJson(s) for s in j["sources"]], + sink=SinkProgress.fromJson(j["sink"]), + numInputRows=j["numInputRows"], + inputRowsPerSecond=j["inputRowsPerSecond"], + processedRowsPerSecond=j["processedRowsPerSecond"], + observedMetrics={ + k: Row(*row_dict.keys())(*row_dict.values()) # Assume no nested rows + for k, row_dict in j["observedMetrics"].items() + }, + ) @property def id(self) -> uuid.UUID: @@ -448,11 +576,11 @@ def observedMetrics(self) -> Dict[str, Row]: return self._observedMetrics @property - def numInputRows(self) -> Optional[str]: + def numInputRows(self) -> int: """ The aggregate (across all sources) number of records processed in a trigger. """ - return self._jprogress.numInputRows() + return self._numInputRows @property def inputRowsPerSecond(self) -> float: @@ -464,7 +592,7 @@ def inputRowsPerSecond(self) -> float: @property def processedRowsPerSecond(self) -> float: """ - The aggregate (across all sources) rate at which Spark is processing data.. + The aggregate (across all sources) rate at which Spark is processing data. """ return self._processedRowsPerSecond @@ -473,14 +601,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -495,20 +631,73 @@ class StateOperatorProgress: This API is evolving. 
""" - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._operatorName: str = jprogress.operatorName() - self._numRowsTotal: int = jprogress.numRowsTotal() - self._numRowsUpdated: int = jprogress.numRowsUpdated() - self._allUpdatesTimeMs: int = jprogress.allUpdatesTimeMs() - self._numRowsRemoved: int = jprogress.numRowsRemoved() - self._allRemovalsTimeMs: int = jprogress.allRemovalsTimeMs() - self._commitTimeMs: int = jprogress.commitTimeMs() - self._memoryUsedBytes: int = jprogress.memoryUsedBytes() - self._numRowsDroppedByWatermark: int = jprogress.numRowsDroppedByWatermark() - self._numShufflePartitions: int = jprogress.numShufflePartitions() - self._numStateStoreInstances: int = jprogress.numStateStoreInstances() - self._customMetrics: Dict[str, int] = dict(jprogress.customMetrics()) + def __init__( + self, + operatorName: str, + numRowsTotal: int, + numRowsUpdated: int, + numRowsRemoved: int, + allUpdatesTimeMs: int, + allRemovalsTimeMs: int, + commitTimeMs: int, + memoryUsedBytes: int, + numRowsDroppedByWatermark: int, + numShufflePartitions: int, + numStateStoreInstances: int, + customMetrics: Dict[str, int], + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, + ): + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict + self._operatorName: str = operatorName + self._numRowsTotal: int = numRowsTotal + self._numRowsUpdated: int = numRowsUpdated + self._numRowsRemoved: int = numRowsRemoved + self._allUpdatesTimeMs: int = allUpdatesTimeMs + self._allRemovalsTimeMs: int = allRemovalsTimeMs + self._commitTimeMs: int = commitTimeMs + self._memoryUsedBytes: int = memoryUsedBytes + self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark + self._numShufflePartitions: int = numShufflePartitions + self._numStateStoreInstances: int = numStateStoreInstances + self._customMetrics: Dict[str, int] = customMetrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress": + return cls( + jprogress=jprogress, + operatorName=jprogress.operatorName(), + numRowsTotal=jprogress.numRowsTotal(), + numRowsUpdated=jprogress.numRowsUpdated(), + allUpdatesTimeMs=jprogress.allUpdatesTimeMs(), + numRowsRemoved=jprogress.numRowsRemoved(), + allRemovalsTimeMs=jprogress.allRemovalsTimeMs(), + commitTimeMs=jprogress.commitTimeMs(), + memoryUsedBytes=jprogress.memoryUsedBytes(), + numRowsDroppedByWatermark=jprogress.numRowsDroppedByWatermark(), + numShufflePartitions=jprogress.numShufflePartitions(), + numStateStoreInstances=jprogress.numStateStoreInstances(), + customMetrics=dict(jprogress.customMetrics()), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress": + return cls( + jdict=j, + operatorName=j["operatorName"], + numRowsTotal=j["numRowsTotal"], + numRowsUpdated=j["numRowsUpdated"], + numRowsRemoved=j["numRowsRemoved"], + allUpdatesTimeMs=j["allUpdatesTimeMs"], + allRemovalsTimeMs=j["allRemovalsTimeMs"], + commitTimeMs=j["commitTimeMs"], + memoryUsedBytes=j["memoryUsedBytes"], + numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"], + numShufflePartitions=j["numShufflePartitions"], + numStateStoreInstances=j["numStateStoreInstances"], + customMetrics=dict(j["customMetrics"]), + ) @property def operatorName(self) -> str: @@ -563,14 +752,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. 
""" - return self._jprogress.json() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. """ - return self._jprogress.prettyJson() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -585,16 +782,57 @@ class SourceProgress: This API is evolving. """ - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._description: str = jprogress.description() - self._startOffset: str = jprogress.startOffset() - self._endOffset: str = jprogress.endOffset() - self._latestOffset: str = jprogress.latestOffset() - self._numInputRows: int = jprogress.numInputRows() - self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond() - self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond() - self._metrics: Dict[str, str] = dict(jprogress.metrics()) + def __init__( + self, + description: str, + startOffset: str, + endOffset: str, + latestOffset: str, + numInputRows: int, + inputRowsPerSecond: float, + processedRowsPerSecond: float, + metrics: Dict[str, str], + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, + ) -> None: + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict + self._description: str = description + self._startOffset: str = startOffset + self._endOffset: str = endOffset + self._latestOffset: str = latestOffset + self._numInputRows: int = numInputRows + self._inputRowsPerSecond: float = inputRowsPerSecond + self._processedRowsPerSecond: float = processedRowsPerSecond + self._metrics: Dict[str, str] = metrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress": + return cls( + jprogress=jprogress, + description=jprogress.description(), + startOffset=str(jprogress.startOffset()), + endOffset=str(jprogress.endOffset()), + latestOffset=str(jprogress.latestOffset()), + numInputRows=jprogress.numInputRows(), + inputRowsPerSecond=jprogress.inputRowsPerSecond(), + processedRowsPerSecond=jprogress.processedRowsPerSecond(), + metrics=dict(jprogress.metrics()), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress": + return cls( + jdict=j, + description=j["description"], + startOffset=str(j["startOffset"]), + endOffset=str(j["endOffset"]), + latestOffset=str(j["latestOffset"]), + numInputRows=j["numInputRows"], + inputRowsPerSecond=j["inputRowsPerSecond"], + processedRowsPerSecond=j["processedRowsPerSecond"], + metrics=dict(j["metrics"]), + ) @property def description(self) -> str: @@ -654,14 +892,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. 
""" - return self._jprogress.prettyJson() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson @@ -676,11 +922,37 @@ class SinkProgress: This API is evolving. """ - def __init__(self, jprogress: JavaObject) -> None: - self._jprogress: JavaObject = jprogress - self._description: str = jprogress.description() - self._numOutputRows: int = jprogress.numOutputRows() - self._metrics: Dict[str, str] = dict(jprogress.metrics()) + def __init__( + self, + description: str, + numOutputRows: int, + metrics: Dict[str, str], + jprogress: Optional[JavaObject] = None, + jdict: Optional[Dict[str, Any]] = None, + ) -> None: + self._jprogress: Optional[JavaObject] = jprogress + self._jdict: Optional[Dict[str, Any]] = jdict + self._description: str = description + self._numOutputRows: int = numOutputRows + self._metrics: Dict[str, str] = metrics + + @classmethod + def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress": + return cls( + jprogress=jprogress, + description=jprogress.description(), + numOutputRows=jprogress.numOutputRows(), + metrics=dict(jprogress.metrics()), + ) + + @classmethod + def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress": + return cls( + jdict=j, + description=j["description"], + numOutputRows=j["numOutputRows"], + metrics=j["metrics"], + ) @property def description(self) -> str: @@ -706,14 +978,22 @@ def json(self) -> str: """ The compact JSON representation of this progress. """ - return self._jprogress.json() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.json() + else: + return json.dumps(self._jdict) @property def prettyJson(self) -> str: """ The pretty (i.e. indented) JSON representation of this progress. 
""" - return self._jprogress.prettyJson() + assert self._jdict is not None or self._jprogress is not None + if self._jprogress: + return self._jprogress.prettyJson() + else: + return json.dumps(self._jdict, indent=4) def __str__(self) -> str: return self.prettyJson diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 71d76bc4e8d52..2bd6d2c666837 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -19,6 +19,7 @@ import uuid from datetime import datetime +from pyspark import Row from pyspark.sql.streaming import StreamingQueryListener from pyspark.sql.streaming.listener import ( QueryStartedEvent, @@ -51,21 +52,21 @@ def get_number_of_public_methods(clz): get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent" ), - 14, + 15, msg, ) self.assertEquals( get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent" ), - 11, + 12, msg, ) self.assertEquals( get_number_of_public_methods( "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent" ), - 14, + 15, msg, ) self.assertEquals( @@ -112,7 +113,15 @@ def onQueryTerminated(self, event): self.spark.streams.addListener(test_listener) df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() - q = df.writeStream.format("noop").queryName("test").start() + + # check successful stateful query + df_stateful = df.groupBy().count() # make query stateful + q = ( + df_stateful.writeStream.format("noop") + .queryName("test") + .outputMode("complete") + .start() + ) self.assertTrue(q.isActive) time.sleep(10) q.stop() @@ -123,6 +132,17 @@ def onQueryTerminated(self, event): self.check_start_event(start_event) self.check_progress_event(progress_event) self.check_terminated_event(terminated_event) + + # Check query terminated with exception + from pyspark.sql.functions import col, udf + + bad_udf = udf(lambda x: 1 / 0) + q = df.select(bad_udf(col("value"))).writeStream.format("noop").start() + time.sleep(5) + q.stop() + self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty() + self.check_terminated_event(terminated_event, "ZeroDivisionError") + finally: self.spark.streams.removeListener(test_listener) @@ -131,7 +151,7 @@ def check_start_event(self, event): self.assertTrue(isinstance(event, QueryStartedEvent)) self.assertTrue(isinstance(event.id, uuid.UUID)) self.assertTrue(isinstance(event.runId, uuid.UUID)) - self.assertEquals(event.name, "test") + self.assertTrue(event.name is None or event.name == "test") try: datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") except ValueError: @@ -142,14 +162,20 @@ def check_progress_event(self, event): self.assertTrue(isinstance(event, QueryProgressEvent)) self.check_streaming_query_progress(event.progress) - def check_terminated_event(self, event): + def check_terminated_event(self, event, exception=None, error_class=None): """Check QueryTerminatedEvent""" self.assertTrue(isinstance(event, QueryTerminatedEvent)) self.assertTrue(isinstance(event.id, uuid.UUID)) self.assertTrue(isinstance(event.runId, uuid.UUID)) - # TODO: Needs a test for exception. 
-        self.assertEquals(event.exception, None)
-        self.assertEquals(event.errorClassOnException, None)
+        if exception:
+            self.assertTrue(exception in event.exception)
+        else:
+            self.assertEquals(event.exception, None)
+
+        if error_class:
+            self.assertTrue(error_class in event.errorClassOnException)
+        else:
+            self.assertEquals(event.errorClassOnException, None)
 
     def check_streaming_query_progress(self, progress):
         """Check StreamingQueryProgress"""
@@ -191,13 +217,15 @@ def check_streaming_query_progress(self, progress):
         )
         self.assertTrue(all(map(lambda v: isinstance(v, int), progress.durationMs.values())))
 
-        self.assertEquals(progress.eventTime, {})
+        self.assertTrue(all(map(lambda v: isinstance(v, str), progress.eventTime.values())))
 
         self.assertTrue(isinstance(progress.stateOperators, list))
+        self.assertTrue(len(progress.stateOperators) >= 1)
         for so in progress.stateOperators:
             self.check_state_operator_progress(so)
 
         self.assertTrue(isinstance(progress.sources, list))
+        self.assertTrue(len(progress.sources) >= 1)
         for so in progress.sources:
             self.check_source_progress(so)
 
@@ -299,6 +327,179 @@ def onQueryTerminated(self, event):
         self.spark.streams.removeListener(test_listener)
         self.assertEqual(num_listeners, len(self.spark.streams._jsqm.listListeners()))
 
+    def test_query_started_event_fromJson(self):
+        start_event = """
+            {
+                "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
+                "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
+                "name" : null,
+                "timestamp" : "2023-06-09T18:13:29.741Z"
+            }
+        """
+        start_event = QueryStartedEvent.fromJson(json.loads(start_event))
+        self.check_start_event(start_event)
+        self.assertEqual(start_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(start_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIsNone(start_event.name)
+        self.assertEqual(start_event.timestamp, "2023-06-09T18:13:29.741Z")
+
+    def test_query_terminated_event_fromJson(self):
+        terminated_json = """
+            {
+                "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
+                "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
+                "exception" : "org.apache.spark.SparkException: Job aborted due to stage failure",
+                "errorClassOnException" : null}
+        """
+        terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json))
+        self.check_terminated_event(terminated_event, "SparkException")
+        self.assertEqual(terminated_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(terminated_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIn("SparkException", terminated_event.exception)
+        self.assertIsNone(terminated_event.errorClassOnException)
+
+    def test_streaming_query_progress_fromJson(self):
+        progress_json = """
+            {
+                "id" : "00000000-0000-0001-0000-000000000001",
+                "runId" : "00000000-0000-0001-0000-000000000002",
+                "name" : "test",
+                "timestamp" : "2016-12-05T20:54:20.827Z",
+                "batchId" : 2,
+                "numInputRows" : 678,
+                "inputRowsPerSecond" : 10.0,
+                "processedRowsPerSecond" : 5.4,
+                "batchDuration": 5,
+                "durationMs" : {
+                    "getBatch" : 0
+                },
+                "eventTime" : {
+                    "min" : "2016-12-05T20:54:20.827Z",
+                    "avg" : "2016-12-05T20:54:20.827Z",
+                    "watermark" : "2016-12-05T20:54:20.827Z",
+                    "max" : "2016-12-05T20:54:20.827Z"
+                },
+                "stateOperators" : [ {
+                    "operatorName" : "op1",
+                    "numRowsTotal" : 0,
+                    "numRowsUpdated" : 1,
+                    "allUpdatesTimeMs" : 1,
+                    "numRowsRemoved" : 2,
+                    "allRemovalsTimeMs" : 34,
+                    "commitTimeMs" : 23,
+                    "memoryUsedBytes" : 3,
+                    "numRowsDroppedByWatermark" : 0,
+                    "numShufflePartitions" : 2,
"numStateStoreInstances" : 2, + "customMetrics" : { + "loadedMapCacheHitCount" : 1, + "loadedMapCacheMissCount" : 0, + "stateOnCurrentVersionSizeBytes" : 2 + } + } ], + "sources" : [ { + "description" : "source", + "startOffset" : 123, + "endOffset" : 456, + "latestOffset" : 789, + "numInputRows" : 678, + "inputRowsPerSecond" : 10.0, + "processedRowsPerSecond" : 5.4, + "metrics": {} + } ], + "sink" : { + "description" : "sink", + "numOutputRows" : -1, + "metrics": {} + }, + "observedMetrics" : { + "event1" : { + "c1" : 1, + "c2" : 3.0 + }, + "event2" : { + "rc" : 1, + "min_q" : "hello", + "max_q" : "world" + } + } + } + """ + progress = StreamingQueryProgress.fromJson(json.loads(progress_json)) + + self.check_streaming_query_progress(progress) + + # checks for progress + self.assertEqual(progress.id, uuid.UUID("00000000-0000-0001-0000-000000000001")) + self.assertEqual(progress.runId, uuid.UUID("00000000-0000-0001-0000-000000000002")) + self.assertEqual(progress.name, "test") + self.assertEqual(progress.timestamp, "2016-12-05T20:54:20.827Z") + self.assertEqual(progress.batchId, 2) + self.assertEqual(progress.numInputRows, 678) + self.assertEqual(progress.inputRowsPerSecond, 10.0) + self.assertEqual(progress.batchDuration, 5) + self.assertEqual(progress.durationMs, {"getBatch": 0}) + self.assertEqual( + progress.eventTime, + { + "min": "2016-12-05T20:54:20.827Z", + "avg": "2016-12-05T20:54:20.827Z", + "watermark": "2016-12-05T20:54:20.827Z", + "max": "2016-12-05T20:54:20.827Z", + }, + ) + self.assertEqual( + progress.observedMetrics, + { + "event1": Row("c1", "c2")(1, 3.0), + "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"), + }, + ) + + # Check stateOperators list + self.assertEqual(len(progress.stateOperators), 1) + state_operator = progress.stateOperators[0] + self.assertTrue(isinstance(state_operator, StateOperatorProgress)) + self.assertEqual(state_operator.operatorName, "op1") + self.assertEqual(state_operator.numRowsTotal, 0) + self.assertEqual(state_operator.numRowsUpdated, 1) + self.assertEqual(state_operator.allUpdatesTimeMs, 1) + self.assertEqual(state_operator.numRowsRemoved, 2) + self.assertEqual(state_operator.allRemovalsTimeMs, 34) + self.assertEqual(state_operator.commitTimeMs, 23) + self.assertEqual(state_operator.memoryUsedBytes, 3) + self.assertEqual(state_operator.numRowsDroppedByWatermark, 0) + self.assertEqual(state_operator.numShufflePartitions, 2) + self.assertEqual(state_operator.numStateStoreInstances, 2) + self.assertEqual( + state_operator.customMetrics, + { + "loadedMapCacheHitCount": 1, + "loadedMapCacheMissCount": 0, + "stateOnCurrentVersionSizeBytes": 2, + }, + ) + + # Check sources list + self.assertEqual(len(progress.sources), 1) + source = progress.sources[0] + self.assertTrue(isinstance(source, SourceProgress)) + self.assertEqual(source.description, "source") + self.assertEqual(source.startOffset, "123") + self.assertEqual(source.endOffset, "456") + self.assertEqual(source.latestOffset, "789") + self.assertEqual(source.numInputRows, 678) + self.assertEqual(source.inputRowsPerSecond, 10.0) + self.assertEqual(source.processedRowsPerSecond, 5.4) + self.assertEqual(source.metrics, {}) + + # Check sink + sink = progress.sink + self.assertTrue(isinstance(sink, SinkProgress)) + self.assertEqual(sink.description, "sink") + self.assertEqual(sink.numOutputRows, -1) + self.assertEqual(sink.metrics, {}) + if __name__ == "__main__": import unittest diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala index 61a0ef1b98e54..5c0027895cda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala @@ -19,6 +19,11 @@ package org.apache.spark.sql.streaming import java.util.UUID +import org.json4s.{JObject, JString} +import org.json4s.JsonAST.JValue +import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc} +import org.json4s.jackson.JsonMethods.{compact, render} + import org.apache.spark.annotation.Evolving import org.apache.spark.scheduler.SparkListenerEvent @@ -123,7 +128,17 @@ object StreamingQueryListener { val id: UUID, val runId: UUID, val name: String, - val timestamp: String) extends Event + val timestamp: String) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("name" -> JString(name)) ~ + ("timestamp" -> JString(timestamp)) + } + } /** * Event representing any progress updates in a query. @@ -131,7 +146,12 @@ object StreamingQueryListener { * @since 2.1.0 */ @Evolving - class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event + class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = JObject("progress" -> progress.jsonValue) + } /** * Event representing that query is idle and waiting for new data to process. @@ -145,7 +165,16 @@ object StreamingQueryListener { class QueryIdleEvent private[sql]( val id: UUID, val runId: UUID, - val timestamp: String) extends Event + val timestamp: String) extends Event { + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("timestamp" -> JString(timestamp)) + } + } /** * Event representing that termination of a query. @@ -171,5 +200,14 @@ object StreamingQueryListener { def this(id: UUID, runId: UUID, exception: Option[String]) = { this(id, runId, exception, None) } + + def json: String = compact(render(jsonValue)) + + private def jsonValue: JValue = { + ("id" -> JString(id.toString)) ~ + ("runId" -> JString(runId.toString)) ~ + ("exception" -> JString(exception.orNull)) ~ + ("errorClassOnException" -> JString(errorClassOnException.orNull)) + } } }
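
As a usage sketch for reviewers (not part of the patch): with these changes, the Scala-side events can serialize themselves via the new `json` methods, and the Python side can rebuild the same event objects from that JSON without a live JVM object, which is what the `fromJson` classmethods are for. A minimal, illustrative round-trip follows, assuming only the APIs added above (the UUIDs and timestamp are made-up values reused from the tests):

```python
# Illustrative sketch only: rebuild a listener event from its JSON form.
import json
import uuid

from pyspark.sql.streaming.listener import QueryStartedEvent

payload = """
{
  "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
  "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
  "name" : null,
  "timestamp" : "2023-06-09T18:13:29.741Z"
}
"""
event = QueryStartedEvent.fromJson(json.loads(payload))
assert isinstance(event.id, uuid.UUID)  # id/runId strings are parsed into uuid.UUID
assert event.name is None  # JSON null maps to Python None
assert event.timestamp == "2023-06-09T18:13:29.741Z"
```

The same pattern applies to `QueryProgressEvent`, `QueryIdleEvent`, and `QueryTerminatedEvent`. For progress objects built this way, the `json`/`prettyJson` properties fall back to `json.dumps` on the stored dict rather than delegating to the JVM, so they work on both construction paths.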